diff --git a/Clients/Java/sample/build.gradle.kts b/Clients/Java/sample/build.gradle.kts index 701f347..17bf187 100644 --- a/Clients/Java/sample/build.gradle.kts +++ b/Clients/Java/sample/build.gradle.kts @@ -13,6 +13,8 @@ dependencies { implementation("org.jetbrains.kotlin:kotlin-bom") implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC") + implementation(rootProject) } @@ -24,4 +26,4 @@ java { application { mainClass.set("gay.pizza.stable.diffusion.sample.MainKt") -} \ No newline at end of file +} diff --git a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt index c248231..c040709 100644 --- a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt +++ b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt @@ -4,10 +4,12 @@ import com.google.protobuf.ByteString import gay.pizza.stable.diffusion.StableDiffusion.* import gay.pizza.stable.diffusion.StableDiffusionRpcClient import io.grpc.ManagedChannelBuilder -import kotlin.io.path.Path -import kotlin.io.path.exists -import kotlin.io.path.readBytes -import kotlin.io.path.writeBytes +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.runBlocking +import java.nio.file.Path +import java.util.concurrent.atomic.AtomicInteger +import kotlin.io.path.* import kotlin.system.exitProcess fun main(args: Array) { @@ -87,30 +89,58 @@ fun main(args: Array) { startingImage = image } }.build() - for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) { + + val workingDirectory = Path("work") + if (!workingDirectory.exists()) { + workingDirectory.createDirectories() + } + + runBlocking { + val task1 = async { + performImageGeneration(1, client, request, workingDirectory.resolve("task1")) + } + + val task2 = async { + performImageGeneration(2, client, request, workingDirectory.resolve("task2")) + } + + awaitAll(task1, task2) + } + + channel.shutdownNow() +} + +@OptIn(ExperimentalPathApi::class) +suspend fun performImageGeneration(task: Int, client: StableDiffusionRpcClient, request: GenerateImagesRequest, path: Path) { + val updateIndex = AtomicInteger(0) + if (path.exists()) { + path.deleteRecursively() + } + path.createDirectories() + + client.imageGenerationServiceCoroutine.generateImagesStreaming(request).collect { update -> + updateIndex.incrementAndGet() if (update.hasBatchProgress()) { - println("batch=${update.currentBatch} " + + println("task=$task batch=${update.currentBatch} " + "progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " + "overall=${prettyProgressValue(update.overallPercentageComplete)}%") for ((index, image) in update.batchProgress.imagesList.withIndex()) { val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) - println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") - val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}") - path.writeBytes(image.data.toByteArray()) + println("task=$task image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") + val filePath = path.resolve("intermediate_${imageIndex}_${updateIndex}.${image.format.name}") + filePath.writeBytes(image.data.toByteArray()) } } if (update.hasBatchCompleted()) { for ((index, image) in update.batchCompleted.imagesList.withIndex()) { val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) - println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") - val path = Path("work/final_${imageIndex}.${image.format.name}") - path.writeBytes(image.data.toByteArray()) + println("task=$task image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") + val filePath = path.resolve("final_${imageIndex}.${image.format.name}") + filePath.writeBytes(image.data.toByteArray()) } } } - - channel.shutdownNow() } fun prettyProgressValue(value: Float) = String.format("%.2f", value) diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index 9e83b15..6ef2206 100644 --- a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -333,6 +333,9 @@ message GenerateImagesStreamUpdate { GenerateImagesBatchCompletedUpdate batch_completed = 3; } + /** + * The percentage of completion for the entire submitted job. + */ float overall_percentage_complete = 4; }