Clients/Java: Add a multi-task sample that runs two jobs at the same time.

This commit is contained in:
2023-04-24 22:42:45 -07:00
parent f61fe6a18f
commit 3fcd527ba1
3 changed files with 50 additions and 15 deletions

View File

@ -13,6 +13,8 @@ dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom") implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8") implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC")
implementation(rootProject) implementation(rootProject)
} }
@ -24,4 +26,4 @@ java {
application { application {
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt") mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
} }

View File

@ -4,10 +4,12 @@ import com.google.protobuf.ByteString
import gay.pizza.stable.diffusion.StableDiffusion.* import gay.pizza.stable.diffusion.StableDiffusion.*
import gay.pizza.stable.diffusion.StableDiffusionRpcClient import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder import io.grpc.ManagedChannelBuilder
import kotlin.io.path.Path import kotlinx.coroutines.async
import kotlin.io.path.exists import kotlinx.coroutines.awaitAll
import kotlin.io.path.readBytes import kotlinx.coroutines.runBlocking
import kotlin.io.path.writeBytes import java.nio.file.Path
import java.util.concurrent.atomic.AtomicInteger
import kotlin.io.path.*
import kotlin.system.exitProcess import kotlin.system.exitProcess
fun main(args: Array<String>) { fun main(args: Array<String>) {
@ -87,30 +89,58 @@ fun main(args: Array<String>) {
startingImage = image startingImage = image
} }
}.build() }.build()
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
val workingDirectory = Path("work")
if (!workingDirectory.exists()) {
workingDirectory.createDirectories()
}
runBlocking {
val task1 = async {
performImageGeneration(1, client, request, workingDirectory.resolve("task1"))
}
val task2 = async {
performImageGeneration(2, client, request, workingDirectory.resolve("task2"))
}
awaitAll(task1, task2)
}
channel.shutdownNow()
}
@OptIn(ExperimentalPathApi::class)
suspend fun performImageGeneration(task: Int, client: StableDiffusionRpcClient, request: GenerateImagesRequest, path: Path) {
val updateIndex = AtomicInteger(0)
if (path.exists()) {
path.deleteRecursively()
}
path.createDirectories()
client.imageGenerationServiceCoroutine.generateImagesStreaming(request).collect { update ->
updateIndex.incrementAndGet()
if (update.hasBatchProgress()) { if (update.hasBatchProgress()) {
println("batch=${update.currentBatch} " + println("task=$task batch=${update.currentBatch} " +
"progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " + "progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
"overall=${prettyProgressValue(update.overallPercentageComplete)}%") "overall=${prettyProgressValue(update.overallPercentageComplete)}%")
for ((index, image) in update.batchProgress.imagesList.withIndex()) { for ((index, image) in update.batchProgress.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") println("task=$task image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}") val filePath = path.resolve("intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray()) filePath.writeBytes(image.data.toByteArray())
} }
} }
if (update.hasBatchCompleted()) { if (update.hasBatchCompleted()) {
for ((index, image) in update.batchCompleted.imagesList.withIndex()) { for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") println("task=$task image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/final_${imageIndex}.${image.format.name}") val filePath = path.resolve("final_${imageIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray()) filePath.writeBytes(image.data.toByteArray())
} }
} }
} }
channel.shutdownNow()
} }
fun prettyProgressValue(value: Float) = String.format("%.2f", value) fun prettyProgressValue(value: Float) = String.format("%.2f", value)

View File

@ -333,6 +333,9 @@ message GenerateImagesStreamUpdate {
GenerateImagesBatchCompletedUpdate batch_completed = 3; GenerateImagesBatchCompletedUpdate batch_completed = 3;
} }
/**
* The percentage of completion for the entire submitted job.
*/
float overall_percentage_complete = 4; float overall_percentage_complete = 4;
} }