mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 13:31:32 +00:00
Clients/Java: Add a multi-task sample that runs two jobs at the same time.
This commit is contained in:
@ -13,6 +13,8 @@ dependencies {
|
||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC")
|
||||
|
||||
implementation(rootProject)
|
||||
}
|
||||
|
||||
@ -24,4 +26,4 @@ java {
|
||||
|
||||
application {
|
||||
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
|
||||
}
|
||||
}
|
||||
|
@ -4,10 +4,12 @@ import com.google.protobuf.ByteString
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.*
|
||||
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
||||
import io.grpc.ManagedChannelBuilder
|
||||
import kotlin.io.path.Path
|
||||
import kotlin.io.path.exists
|
||||
import kotlin.io.path.readBytes
|
||||
import kotlin.io.path.writeBytes
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import java.nio.file.Path
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import kotlin.io.path.*
|
||||
import kotlin.system.exitProcess
|
||||
|
||||
fun main(args: Array<String>) {
|
||||
@ -87,30 +89,58 @@ fun main(args: Array<String>) {
|
||||
startingImage = image
|
||||
}
|
||||
}.build()
|
||||
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
|
||||
|
||||
val workingDirectory = Path("work")
|
||||
if (!workingDirectory.exists()) {
|
||||
workingDirectory.createDirectories()
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
val task1 = async {
|
||||
performImageGeneration(1, client, request, workingDirectory.resolve("task1"))
|
||||
}
|
||||
|
||||
val task2 = async {
|
||||
performImageGeneration(2, client, request, workingDirectory.resolve("task2"))
|
||||
}
|
||||
|
||||
awaitAll(task1, task2)
|
||||
}
|
||||
|
||||
channel.shutdownNow()
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalPathApi::class)
|
||||
suspend fun performImageGeneration(task: Int, client: StableDiffusionRpcClient, request: GenerateImagesRequest, path: Path) {
|
||||
val updateIndex = AtomicInteger(0)
|
||||
if (path.exists()) {
|
||||
path.deleteRecursively()
|
||||
}
|
||||
path.createDirectories()
|
||||
|
||||
client.imageGenerationServiceCoroutine.generateImagesStreaming(request).collect { update ->
|
||||
updateIndex.incrementAndGet()
|
||||
if (update.hasBatchProgress()) {
|
||||
println("batch=${update.currentBatch} " +
|
||||
println("task=$task batch=${update.currentBatch} " +
|
||||
"progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
|
||||
"overall=${prettyProgressValue(update.overallPercentageComplete)}%")
|
||||
for ((index, image) in update.batchProgress.imagesList.withIndex()) {
|
||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||
println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
println("task=$task image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val filePath = path.resolve("intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
|
||||
filePath.writeBytes(image.data.toByteArray())
|
||||
}
|
||||
}
|
||||
|
||||
if (update.hasBatchCompleted()) {
|
||||
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||
println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/final_${imageIndex}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
println("task=$task image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val filePath = path.resolve("final_${imageIndex}.${image.format.name}")
|
||||
filePath.writeBytes(image.data.toByteArray())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channel.shutdownNow()
|
||||
}
|
||||
|
||||
fun prettyProgressValue(value: Float) = String.format("%.2f", value)
|
||||
|
@ -333,6 +333,9 @@ message GenerateImagesStreamUpdate {
|
||||
GenerateImagesBatchCompletedUpdate batch_completed = 3;
|
||||
}
|
||||
|
||||
/**
|
||||
* The percentage of completion for the entire submitted job.
|
||||
*/
|
||||
float overall_percentage_complete = 4;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user