Implement a Java/Kotlin client.

This commit is contained in:
2023-04-22 23:24:36 -07:00
parent 0b5f5dae57
commit 9b0c174df4
10 changed files with 195 additions and 11 deletions

View File

@ -1,4 +1,3 @@
name: macOS
on: [push]
jobs:
build:
@ -8,13 +7,20 @@ jobs:
uses: actions/checkout@v3
- name: Build Executable
run: swift build -c release --arch arm64 --arch x86_64
- name: Copy Executable
run: cp .build/apple/Products/Release/StableDiffusionServer StableDiffusionServer
- name: Archive Executable
- name: Copy Server Executable
run: cp .build/apple/Products/Release/stable-diffusion-rpc stable-diffusion-rpc
- name: Archive Server Executable
uses: actions/upload-artifact@v2
with:
name: StableDiffusionServer
path: StableDiffusionServer
name: stable-diffusion-rpc
path: stable-diffusion-rpc
- name: Copy Control Executable
run: cp .build/apple/Products/Release/stable-diffusion-ctl stable-diffusion-ctl
- name: Archive Control Executable
uses: actions/upload-artifact@v2
with:
name: stable-diffusion-rpc
path: stable-diffusion-rpc
format:
runs-on: macos-12
steps:

17
.github/workflows/java.yml vendored Normal file
View File

@ -0,0 +1,17 @@
on: [push]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v3
- name: Set up JDK 17
uses: actions/setup-java@v3
with:
java-version: '17'
distribution: 'temurin'
- name: Build with Gradle
uses: gradle/gradle-build-action@v2
with:
arguments: build
build-root-directory: Clients/Java

View File

@ -1,15 +1,17 @@
import com.google.protobuf.gradle.id
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
group = "gay.pizza.stable.diffusion"
version = "1.0.0-SNAPSHOT"
plugins {
application
`java-library`
`maven-publish`
kotlin("jvm") version "1.8.20"
kotlin("plugin.serialization") version "1.8.20"
`maven-publish`
id("com.google.protobuf") version "0.9.2"
}
repositories {
@ -22,9 +24,53 @@ java {
targetCompatibility = javaVersion
}
sourceSets {
main {
proto {
srcDir("../../Common")
include("*.proto")
}
}
}
dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
api("io.grpc:grpc-stub:1.54.1")
api("io.grpc:grpc-protobuf:1.54.1")
api("io.grpc:grpc-kotlin-stub:1.3.0")
implementation("io.grpc:grpc-netty:1.54.1")
}
protobuf {
protoc {
artifact = "com.google.protobuf:protoc:3.22.3"
}
plugins {
create("grpc") {
artifact = "io.grpc:protoc-gen-grpc-java:1.54.1"
}
create("grpckt") {
artifact = "io.grpc:protoc-gen-grpc-kotlin:1.3.0:jdk8@jar"
}
}
generateProtoTasks {
all().configureEach {
builtins {
java {}
kotlin {}
}
plugins {
id("grpc") {}
id("grpckt") {}
}
}
}
}
publishing {

View File

@ -0,0 +1,27 @@
plugins {
application
kotlin("jvm")
kotlin("plugin.serialization")
}
repositories {
mavenCentral()
}
dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation(rootProject)
}
java {
val javaVersion = JavaVersion.toVersion(17)
sourceCompatibility = javaVersion
targetCompatibility = javaVersion
}
application {
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
}

View File

@ -0,0 +1,49 @@
package gay.pizza.stable.diffusion.sample
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlin.system.exitProcess
fun main() {
val channel = ManagedChannelBuilder
.forAddress("127.0.0.1", 4546)
.usePlaintext()
.build()
val client = StableDiffusionRpcClient(channel)
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
if (modelListResponse.modelsList.isEmpty()) {
println("no available models")
exitProcess(0)
}
println("available models:")
for (model in modelListResponse.modelsList) {
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
}
val model = modelListResponse.modelsList.random()
if (!model.isLoaded) {
println("loading model ${model.name}...")
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
modelName = model.name
}.build())
} else {
println("using model ${model.name}...")
}
println("generating images...")
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
modelName = model.name
imageCount = 1
prompt = "cat"
negativePrompt = "bad, low quality, nsfw"
}.build())
println("generated ${generateImagesResponse.imagesCount} images")
channel.shutdownNow()
}

View File

@ -1 +1,2 @@
rootProject.name = "stable-diffusion-rpc"
include("sample")

View File

@ -0,0 +1,38 @@
package gay.pizza.stable.diffusion
import io.grpc.Channel
@Suppress("MemberVisibilityCanBePrivate")
class StableDiffusionRpcClient(val channel: Channel) {
val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
ModelServiceGrpc.newStub(channel)
}
val modelServiceBlocking: ModelServiceGrpc.ModelServiceBlockingStub by lazy {
ModelServiceGrpc.newBlockingStub(channel)
}
val modelServiceFuture: ModelServiceGrpc.ModelServiceFutureStub by lazy {
ModelServiceGrpc.newFutureStub(channel)
}
val modelServiceCoroutine: ModelServiceGrpcKt.ModelServiceCoroutineStub by lazy {
ModelServiceGrpcKt.ModelServiceCoroutineStub(channel)
}
val imageGenerationService: ImageGenerationServiceGrpc.ImageGenerationServiceStub by lazy {
ImageGenerationServiceGrpc.newStub(channel)
}
val imageGenerationServiceBlocking: ImageGenerationServiceGrpc.ImageGenerationServiceBlockingStub by lazy {
ImageGenerationServiceGrpc.newBlockingStub(channel)
}
val imageGenerationServiceFuture: ImageGenerationServiceGrpc.ImageGenerationServiceFutureStub by lazy {
ImageGenerationServiceGrpc.newFutureStub(channel)
}
val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
}
}

View File

@ -5,9 +5,9 @@ let package = Package(
name: "stable-diffusion-rpc",
platforms: [.macOS("13.1"), .iOS("16.2")],
products: [
.executable(name: "StableDiffusionServer", targets: ["StableDiffusionServer"]),
.executable(name: "stable-diffusion-rpc", targets: ["StableDiffusionServer"]),
.library(name: "StableDiffusionProtos", targets: ["StableDiffusionProtos"]),
.executable(name: "TestStableDiffusionClient", targets: ["TestStableDiffusionClient"])
.executable(name: "stable-diffusion-ctl", targets: ["StableDiffusionControl"])
],
dependencies: [
.package(url: "https://github.com/apple/ml-stable-diffusion", revision: "5d2744e38297b01662b8bdfb41e899ac98036d8b"),
@ -32,7 +32,7 @@ let package = Package(
.target(name: "StableDiffusionCore"),
.product(name: "ArgumentParser", package: "swift-argument-parser")
]),
.executableTarget(name: "TestStableDiffusionClient", dependencies: [
.executableTarget(name: "StableDiffusionControl", dependencies: [
.target(name: "StableDiffusionProtos"),
.target(name: "StableDiffusionCore"),
.product(name: "GRPC", package: "grpc-swift")