Document API, make the implementation match the API, and update the same.

This commit is contained in:
2023-04-23 02:09:50 -07:00
parent 71afed326f
commit 7c0b2779f4
15 changed files with 707 additions and 310 deletions

View File

@ -26,15 +26,6 @@ java {
withSourcesJar()
}
sourceSets {
main {
proto {
srcDir("../../Common")
include("*.proto")
}
}
}
dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")

View File

@ -1,10 +1,13 @@
package gay.pizza.stable.diffusion.sample
import gay.pizza.stable.diffusion.StableDiffusion
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlin.io.path.Path
import kotlin.io.path.writeBytes
import kotlin.system.exitProcess
fun main() {
@ -15,20 +18,22 @@ fun main() {
val client = StableDiffusionRpcClient(channel)
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
if (modelListResponse.modelsList.isEmpty()) {
if (modelListResponse.availableModelsList.isEmpty()) {
println("no available models")
exitProcess(0)
}
println("available models:")
for (model in modelListResponse.modelsList) {
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
for (model in modelListResponse.availableModelsList) {
val maybeLoadedComputeUnits = if (model.isLoaded) " loaded_compute_units=${model.loadedComputeUnits.name}" else ""
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
}
val model = modelListResponse.modelsList.random()
val model = modelListResponse.availableModelsList.random()
if (!model.isLoaded) {
println("loading model ${model.name}...")
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
modelName = model.name
computeUnits = model.supportedComputeUnitsList.first()
}.build())
} else {
println("using model ${model.name}...")
@ -36,14 +41,22 @@ fun main() {
println("generating images...")
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
val request = GenerateImagesRequest.newBuilder().apply {
modelName = model.name
imageCount = 1
outputImageFormat = StableDiffusion.ImageFormat.png
batchSize = 2
batchCount = 2
prompt = "cat"
negativePrompt = "bad, low quality, nsfw"
}.build())
}.build()
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
println("generated ${generateImagesResponse.imagesCount} images")
println("generated ${generateImagesResponse.imagesCount} images:")
for ((index, image) in generateImagesResponse.imagesList.withIndex()) {
println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/image${index}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
}
channel.shutdownNow()
}

1
Clients/Java/src/main/proto Symbolic link
View File

@ -0,0 +1 @@
../../../../Common