Implement a Java/Kotlin client.

This commit is contained in:
2023-04-22 23:24:36 -07:00
parent 0b5f5dae57
commit 9b0c174df4
10 changed files with 195 additions and 11 deletions

View File

@ -1,15 +1,17 @@
import com.google.protobuf.gradle.id
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
group = "gay.pizza.stable.diffusion"
version = "1.0.0-SNAPSHOT"
plugins {
application
`java-library`
`maven-publish`
kotlin("jvm") version "1.8.20"
kotlin("plugin.serialization") version "1.8.20"
`maven-publish`
id("com.google.protobuf") version "0.9.2"
}
repositories {
@ -22,9 +24,53 @@ java {
targetCompatibility = javaVersion
}
sourceSets {
main {
proto {
srcDir("../../Common")
include("*.proto")
}
}
}
dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
api("io.grpc:grpc-stub:1.54.1")
api("io.grpc:grpc-protobuf:1.54.1")
api("io.grpc:grpc-kotlin-stub:1.3.0")
implementation("io.grpc:grpc-netty:1.54.1")
}
protobuf {
protoc {
artifact = "com.google.protobuf:protoc:3.22.3"
}
plugins {
create("grpc") {
artifact = "io.grpc:protoc-gen-grpc-java:1.54.1"
}
create("grpckt") {
artifact = "io.grpc:protoc-gen-grpc-kotlin:1.3.0:jdk8@jar"
}
}
generateProtoTasks {
all().configureEach {
builtins {
java {}
kotlin {}
}
plugins {
id("grpc") {}
id("grpckt") {}
}
}
}
}
publishing {

View File

@ -0,0 +1,27 @@
plugins {
application
kotlin("jvm")
kotlin("plugin.serialization")
}
repositories {
mavenCentral()
}
dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation(rootProject)
}
java {
val javaVersion = JavaVersion.toVersion(17)
sourceCompatibility = javaVersion
targetCompatibility = javaVersion
}
application {
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
}

View File

@ -0,0 +1,49 @@
package gay.pizza.stable.diffusion.sample
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlin.system.exitProcess
fun main() {
val channel = ManagedChannelBuilder
.forAddress("127.0.0.1", 4546)
.usePlaintext()
.build()
val client = StableDiffusionRpcClient(channel)
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
if (modelListResponse.modelsList.isEmpty()) {
println("no available models")
exitProcess(0)
}
println("available models:")
for (model in modelListResponse.modelsList) {
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
}
val model = modelListResponse.modelsList.random()
if (!model.isLoaded) {
println("loading model ${model.name}...")
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
modelName = model.name
}.build())
} else {
println("using model ${model.name}...")
}
println("generating images...")
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
modelName = model.name
imageCount = 1
prompt = "cat"
negativePrompt = "bad, low quality, nsfw"
}.build())
println("generated ${generateImagesResponse.imagesCount} images")
channel.shutdownNow()
}

View File

@ -1 +1,2 @@
rootProject.name = "stable-diffusion-rpc"
include("sample")

View File

@ -0,0 +1,38 @@
package gay.pizza.stable.diffusion
import io.grpc.Channel
@Suppress("MemberVisibilityCanBePrivate")
class StableDiffusionRpcClient(val channel: Channel) {
val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
ModelServiceGrpc.newStub(channel)
}
val modelServiceBlocking: ModelServiceGrpc.ModelServiceBlockingStub by lazy {
ModelServiceGrpc.newBlockingStub(channel)
}
val modelServiceFuture: ModelServiceGrpc.ModelServiceFutureStub by lazy {
ModelServiceGrpc.newFutureStub(channel)
}
val modelServiceCoroutine: ModelServiceGrpcKt.ModelServiceCoroutineStub by lazy {
ModelServiceGrpcKt.ModelServiceCoroutineStub(channel)
}
val imageGenerationService: ImageGenerationServiceGrpc.ImageGenerationServiceStub by lazy {
ImageGenerationServiceGrpc.newStub(channel)
}
val imageGenerationServiceBlocking: ImageGenerationServiceGrpc.ImageGenerationServiceBlockingStub by lazy {
ImageGenerationServiceGrpc.newBlockingStub(channel)
}
val imageGenerationServiceFuture: ImageGenerationServiceGrpc.ImageGenerationServiceFutureStub by lazy {
ImageGenerationServiceGrpc.newFutureStub(channel)
}
val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
}
}