mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-04 05:51:32 +00:00
Document API, make the implementation match the API, and update the same.
This commit is contained in:
@ -26,15 +26,6 @@ java {
|
||||
withSourcesJar()
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
proto {
|
||||
srcDir("../../Common")
|
||||
include("*.proto")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||
|
@ -1,10 +1,13 @@
|
||||
package gay.pizza.stable.diffusion.sample
|
||||
|
||||
import gay.pizza.stable.diffusion.StableDiffusion
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
||||
import io.grpc.ManagedChannelBuilder
|
||||
import kotlin.io.path.Path
|
||||
import kotlin.io.path.writeBytes
|
||||
import kotlin.system.exitProcess
|
||||
|
||||
fun main() {
|
||||
@ -15,20 +18,22 @@ fun main() {
|
||||
|
||||
val client = StableDiffusionRpcClient(channel)
|
||||
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
|
||||
if (modelListResponse.modelsList.isEmpty()) {
|
||||
if (modelListResponse.availableModelsList.isEmpty()) {
|
||||
println("no available models")
|
||||
exitProcess(0)
|
||||
}
|
||||
println("available models:")
|
||||
for (model in modelListResponse.modelsList) {
|
||||
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
|
||||
for (model in modelListResponse.availableModelsList) {
|
||||
val maybeLoadedComputeUnits = if (model.isLoaded) " loaded_compute_units=${model.loadedComputeUnits.name}" else ""
|
||||
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
|
||||
}
|
||||
|
||||
val model = modelListResponse.modelsList.random()
|
||||
val model = modelListResponse.availableModelsList.random()
|
||||
if (!model.isLoaded) {
|
||||
println("loading model ${model.name}...")
|
||||
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
computeUnits = model.supportedComputeUnitsList.first()
|
||||
}.build())
|
||||
} else {
|
||||
println("using model ${model.name}...")
|
||||
@ -36,14 +41,22 @@ fun main() {
|
||||
|
||||
println("generating images...")
|
||||
|
||||
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
|
||||
val request = GenerateImagesRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
imageCount = 1
|
||||
outputImageFormat = StableDiffusion.ImageFormat.png
|
||||
batchSize = 2
|
||||
batchCount = 2
|
||||
prompt = "cat"
|
||||
negativePrompt = "bad, low quality, nsfw"
|
||||
}.build())
|
||||
}.build()
|
||||
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
|
||||
|
||||
println("generated ${generateImagesResponse.imagesCount} images")
|
||||
println("generated ${generateImagesResponse.imagesCount} images:")
|
||||
for ((index, image) in generateImagesResponse.imagesList.withIndex()) {
|
||||
println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/image${index}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
}
|
||||
|
||||
channel.shutdownNow()
|
||||
}
|
||||
|
1
Clients/Java/src/main/proto
Symbolic link
1
Clients/Java/src/main/proto
Symbolic link
@ -0,0 +1 @@
|
||||
../../../../Common
|
Reference in New Issue
Block a user