Implement a Java/Kotlin client.

This commit is contained in:
2023-04-22 23:24:36 -07:00
parent 0b5f5dae57
commit 9b0c174df4
10 changed files with 195 additions and 11 deletions

View File

@ -0,0 +1,27 @@
plugins {
application
kotlin("jvm")
kotlin("plugin.serialization")
}
repositories {
mavenCentral()
}
dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation(rootProject)
}
java {
val javaVersion = JavaVersion.toVersion(17)
sourceCompatibility = javaVersion
targetCompatibility = javaVersion
}
application {
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
}

View File

@ -0,0 +1,49 @@
package gay.pizza.stable.diffusion.sample
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlin.system.exitProcess
fun main() {
val channel = ManagedChannelBuilder
.forAddress("127.0.0.1", 4546)
.usePlaintext()
.build()
val client = StableDiffusionRpcClient(channel)
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
if (modelListResponse.modelsList.isEmpty()) {
println("no available models")
exitProcess(0)
}
println("available models:")
for (model in modelListResponse.modelsList) {
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
}
val model = modelListResponse.modelsList.random()
if (!model.isLoaded) {
println("loading model ${model.name}...")
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
modelName = model.name
}.build())
} else {
println("using model ${model.name}...")
}
println("generating images...")
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
modelName = model.name
imageCount = 1
prompt = "cat"
negativePrompt = "bad, low quality, nsfw"
}.build())
println("generated ${generateImagesResponse.imagesCount} images")
channel.shutdownNow()
}