mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 05:30:54 +00:00
Implement a Java/Kotlin client.
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
name: macOS
|
||||
on: [push]
|
||||
jobs:
|
||||
build:
|
||||
@ -8,13 +7,20 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
- name: Build Executable
|
||||
run: swift build -c release --arch arm64 --arch x86_64
|
||||
- name: Copy Executable
|
||||
run: cp .build/apple/Products/Release/StableDiffusionServer StableDiffusionServer
|
||||
- name: Archive Executable
|
||||
- name: Copy Server Executable
|
||||
run: cp .build/apple/Products/Release/stable-diffusion-rpc stable-diffusion-rpc
|
||||
- name: Archive Server Executable
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: StableDiffusionServer
|
||||
path: StableDiffusionServer
|
||||
name: stable-diffusion-rpc
|
||||
path: stable-diffusion-rpc
|
||||
- name: Copy Control Executable
|
||||
run: cp .build/apple/Products/Release/stable-diffusion-ctl stable-diffusion-ctl
|
||||
- name: Archive Control Executable
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: stable-diffusion-rpc
|
||||
path: stable-diffusion-rpc
|
||||
format:
|
||||
runs-on: macos-12
|
||||
steps:
|
17
.github/workflows/java.yml
vendored
Normal file
17
.github/workflows/java.yml
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
on: [push]
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up JDK 17
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '17'
|
||||
distribution: 'temurin'
|
||||
- name: Build with Gradle
|
||||
uses: gradle/gradle-build-action@v2
|
||||
with:
|
||||
arguments: build
|
||||
build-root-directory: Clients/Java
|
@ -1,15 +1,17 @@
|
||||
import com.google.protobuf.gradle.id
|
||||
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
||||
|
||||
group = "gay.pizza.stable.diffusion"
|
||||
version = "1.0.0-SNAPSHOT"
|
||||
|
||||
plugins {
|
||||
application
|
||||
`java-library`
|
||||
`maven-publish`
|
||||
|
||||
kotlin("jvm") version "1.8.20"
|
||||
kotlin("plugin.serialization") version "1.8.20"
|
||||
|
||||
`maven-publish`
|
||||
id("com.google.protobuf") version "0.9.2"
|
||||
}
|
||||
|
||||
repositories {
|
||||
@ -22,9 +24,53 @@ java {
|
||||
targetCompatibility = javaVersion
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
proto {
|
||||
srcDir("../../Common")
|
||||
include("*.proto")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||
|
||||
api("io.grpc:grpc-stub:1.54.1")
|
||||
api("io.grpc:grpc-protobuf:1.54.1")
|
||||
api("io.grpc:grpc-kotlin-stub:1.3.0")
|
||||
implementation("io.grpc:grpc-netty:1.54.1")
|
||||
}
|
||||
|
||||
protobuf {
|
||||
protoc {
|
||||
artifact = "com.google.protobuf:protoc:3.22.3"
|
||||
}
|
||||
|
||||
plugins {
|
||||
create("grpc") {
|
||||
artifact = "io.grpc:protoc-gen-grpc-java:1.54.1"
|
||||
}
|
||||
|
||||
create("grpckt") {
|
||||
artifact = "io.grpc:protoc-gen-grpc-kotlin:1.3.0:jdk8@jar"
|
||||
}
|
||||
}
|
||||
|
||||
generateProtoTasks {
|
||||
all().configureEach {
|
||||
builtins {
|
||||
java {}
|
||||
kotlin {}
|
||||
}
|
||||
|
||||
plugins {
|
||||
id("grpc") {}
|
||||
id("grpckt") {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
publishing {
|
||||
|
27
Clients/Java/sample/build.gradle.kts
Normal file
27
Clients/Java/sample/build.gradle.kts
Normal file
@ -0,0 +1,27 @@
|
||||
plugins {
|
||||
application
|
||||
|
||||
kotlin("jvm")
|
||||
kotlin("plugin.serialization")
|
||||
}
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||
|
||||
implementation(rootProject)
|
||||
}
|
||||
|
||||
java {
|
||||
val javaVersion = JavaVersion.toVersion(17)
|
||||
sourceCompatibility = javaVersion
|
||||
targetCompatibility = javaVersion
|
||||
}
|
||||
|
||||
application {
|
||||
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package gay.pizza.stable.diffusion.sample
|
||||
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
||||
import io.grpc.ManagedChannelBuilder
|
||||
import kotlin.system.exitProcess
|
||||
|
||||
fun main() {
|
||||
val channel = ManagedChannelBuilder
|
||||
.forAddress("127.0.0.1", 4546)
|
||||
.usePlaintext()
|
||||
.build()
|
||||
|
||||
val client = StableDiffusionRpcClient(channel)
|
||||
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
|
||||
if (modelListResponse.modelsList.isEmpty()) {
|
||||
println("no available models")
|
||||
exitProcess(0)
|
||||
}
|
||||
println("available models:")
|
||||
for (model in modelListResponse.modelsList) {
|
||||
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
|
||||
}
|
||||
|
||||
val model = modelListResponse.modelsList.random()
|
||||
if (!model.isLoaded) {
|
||||
println("loading model ${model.name}...")
|
||||
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
}.build())
|
||||
} else {
|
||||
println("using model ${model.name}...")
|
||||
}
|
||||
|
||||
println("generating images...")
|
||||
|
||||
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
imageCount = 1
|
||||
prompt = "cat"
|
||||
negativePrompt = "bad, low quality, nsfw"
|
||||
}.build())
|
||||
|
||||
println("generated ${generateImagesResponse.imagesCount} images")
|
||||
|
||||
channel.shutdownNow()
|
||||
}
|
@ -1 +1,2 @@
|
||||
rootProject.name = "stable-diffusion-rpc"
|
||||
include("sample")
|
||||
|
@ -0,0 +1,38 @@
|
||||
package gay.pizza.stable.diffusion
|
||||
|
||||
import io.grpc.Channel
|
||||
|
||||
@Suppress("MemberVisibilityCanBePrivate")
|
||||
class StableDiffusionRpcClient(val channel: Channel) {
|
||||
val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
|
||||
ModelServiceGrpc.newStub(channel)
|
||||
}
|
||||
|
||||
val modelServiceBlocking: ModelServiceGrpc.ModelServiceBlockingStub by lazy {
|
||||
ModelServiceGrpc.newBlockingStub(channel)
|
||||
}
|
||||
|
||||
val modelServiceFuture: ModelServiceGrpc.ModelServiceFutureStub by lazy {
|
||||
ModelServiceGrpc.newFutureStub(channel)
|
||||
}
|
||||
|
||||
val modelServiceCoroutine: ModelServiceGrpcKt.ModelServiceCoroutineStub by lazy {
|
||||
ModelServiceGrpcKt.ModelServiceCoroutineStub(channel)
|
||||
}
|
||||
|
||||
val imageGenerationService: ImageGenerationServiceGrpc.ImageGenerationServiceStub by lazy {
|
||||
ImageGenerationServiceGrpc.newStub(channel)
|
||||
}
|
||||
|
||||
val imageGenerationServiceBlocking: ImageGenerationServiceGrpc.ImageGenerationServiceBlockingStub by lazy {
|
||||
ImageGenerationServiceGrpc.newBlockingStub(channel)
|
||||
}
|
||||
|
||||
val imageGenerationServiceFuture: ImageGenerationServiceGrpc.ImageGenerationServiceFutureStub by lazy {
|
||||
ImageGenerationServiceGrpc.newFutureStub(channel)
|
||||
}
|
||||
|
||||
val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
|
||||
ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
|
||||
}
|
||||
}
|
@ -5,9 +5,9 @@ let package = Package(
|
||||
name: "stable-diffusion-rpc",
|
||||
platforms: [.macOS("13.1"), .iOS("16.2")],
|
||||
products: [
|
||||
.executable(name: "StableDiffusionServer", targets: ["StableDiffusionServer"]),
|
||||
.executable(name: "stable-diffusion-rpc", targets: ["StableDiffusionServer"]),
|
||||
.library(name: "StableDiffusionProtos", targets: ["StableDiffusionProtos"]),
|
||||
.executable(name: "TestStableDiffusionClient", targets: ["TestStableDiffusionClient"])
|
||||
.executable(name: "stable-diffusion-ctl", targets: ["StableDiffusionControl"])
|
||||
],
|
||||
dependencies: [
|
||||
.package(url: "https://github.com/apple/ml-stable-diffusion", revision: "5d2744e38297b01662b8bdfb41e899ac98036d8b"),
|
||||
@ -32,7 +32,7 @@ let package = Package(
|
||||
.target(name: "StableDiffusionCore"),
|
||||
.product(name: "ArgumentParser", package: "swift-argument-parser")
|
||||
]),
|
||||
.executableTarget(name: "TestStableDiffusionClient", dependencies: [
|
||||
.executableTarget(name: "StableDiffusionControl", dependencies: [
|
||||
.target(name: "StableDiffusionProtos"),
|
||||
.target(name: "StableDiffusionCore"),
|
||||
.product(name: "GRPC", package: "grpc-swift")
|
||||
|
Reference in New Issue
Block a user