mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 13:31:32 +00:00
Clients/Java: Add a multi-task sample that runs two jobs at the same time.
This commit is contained in:
@ -13,6 +13,8 @@ dependencies {
|
|||||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||||
|
|
||||||
|
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC")
|
||||||
|
|
||||||
implementation(rootProject)
|
implementation(rootProject)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,4 +26,4 @@ java {
|
|||||||
|
|
||||||
application {
|
application {
|
||||||
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
|
mainClass.set("gay.pizza.stable.diffusion.sample.MainKt")
|
||||||
}
|
}
|
||||||
|
@ -4,10 +4,12 @@ import com.google.protobuf.ByteString
|
|||||||
import gay.pizza.stable.diffusion.StableDiffusion.*
|
import gay.pizza.stable.diffusion.StableDiffusion.*
|
||||||
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
||||||
import io.grpc.ManagedChannelBuilder
|
import io.grpc.ManagedChannelBuilder
|
||||||
import kotlin.io.path.Path
|
import kotlinx.coroutines.async
|
||||||
import kotlin.io.path.exists
|
import kotlinx.coroutines.awaitAll
|
||||||
import kotlin.io.path.readBytes
|
import kotlinx.coroutines.runBlocking
|
||||||
import kotlin.io.path.writeBytes
|
import java.nio.file.Path
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger
|
||||||
|
import kotlin.io.path.*
|
||||||
import kotlin.system.exitProcess
|
import kotlin.system.exitProcess
|
||||||
|
|
||||||
fun main(args: Array<String>) {
|
fun main(args: Array<String>) {
|
||||||
@ -87,30 +89,58 @@ fun main(args: Array<String>) {
|
|||||||
startingImage = image
|
startingImage = image
|
||||||
}
|
}
|
||||||
}.build()
|
}.build()
|
||||||
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
|
|
||||||
|
val workingDirectory = Path("work")
|
||||||
|
if (!workingDirectory.exists()) {
|
||||||
|
workingDirectory.createDirectories()
|
||||||
|
}
|
||||||
|
|
||||||
|
runBlocking {
|
||||||
|
val task1 = async {
|
||||||
|
performImageGeneration(1, client, request, workingDirectory.resolve("task1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
val task2 = async {
|
||||||
|
performImageGeneration(2, client, request, workingDirectory.resolve("task2"))
|
||||||
|
}
|
||||||
|
|
||||||
|
awaitAll(task1, task2)
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.shutdownNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalPathApi::class)
|
||||||
|
suspend fun performImageGeneration(task: Int, client: StableDiffusionRpcClient, request: GenerateImagesRequest, path: Path) {
|
||||||
|
val updateIndex = AtomicInteger(0)
|
||||||
|
if (path.exists()) {
|
||||||
|
path.deleteRecursively()
|
||||||
|
}
|
||||||
|
path.createDirectories()
|
||||||
|
|
||||||
|
client.imageGenerationServiceCoroutine.generateImagesStreaming(request).collect { update ->
|
||||||
|
updateIndex.incrementAndGet()
|
||||||
if (update.hasBatchProgress()) {
|
if (update.hasBatchProgress()) {
|
||||||
println("batch=${update.currentBatch} " +
|
println("task=$task batch=${update.currentBatch} " +
|
||||||
"progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
|
"progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
|
||||||
"overall=${prettyProgressValue(update.overallPercentageComplete)}%")
|
"overall=${prettyProgressValue(update.overallPercentageComplete)}%")
|
||||||
for ((index, image) in update.batchProgress.imagesList.withIndex()) {
|
for ((index, image) in update.batchProgress.imagesList.withIndex()) {
|
||||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||||
println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
println("task=$task image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||||
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
|
val filePath = path.resolve("intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
|
||||||
path.writeBytes(image.data.toByteArray())
|
filePath.writeBytes(image.data.toByteArray())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (update.hasBatchCompleted()) {
|
if (update.hasBatchCompleted()) {
|
||||||
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
||||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||||
println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
println("task=$task image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||||
val path = Path("work/final_${imageIndex}.${image.format.name}")
|
val filePath = path.resolve("final_${imageIndex}.${image.format.name}")
|
||||||
path.writeBytes(image.data.toByteArray())
|
filePath.writeBytes(image.data.toByteArray())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.shutdownNow()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun prettyProgressValue(value: Float) = String.format("%.2f", value)
|
fun prettyProgressValue(value: Float) = String.format("%.2f", value)
|
||||||
|
@ -333,6 +333,9 @@ message GenerateImagesStreamUpdate {
|
|||||||
GenerateImagesBatchCompletedUpdate batch_completed = 3;
|
GenerateImagesBatchCompletedUpdate batch_completed = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The percentage of completion for the entire submitted job.
|
||||||
|
*/
|
||||||
float overall_percentage_complete = 4;
|
float overall_percentage_complete = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user