Mirror of https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git (synced 2025-08-04 05:51:32 +00:00)
Add support for BPE tokenization.
@@ -35,6 +35,8 @@ dependencies {
   api("io.grpc:grpc-stub:1.54.1")
   api("io.grpc:grpc-protobuf:1.54.1")
   api("io.grpc:grpc-kotlin-stub:1.3.0")
   implementation("com.google.protobuf:protobuf-java:3.22.3")
+
+  implementation("io.grpc:grpc-netty:1.54.1")
 }
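Note: the new grpc-netty dependency supplies the client transport that the sample's ManagedChannelBuilder resolves from the classpath at runtime. As a minimal sketch (not part of this commit; the endpoint is taken from the sample below, everything else is an assumption), the same plaintext channel could also be built against the Netty transport explicitly:

import io.grpc.ManagedChannel
import io.grpc.netty.NettyChannelBuilder

// Sketch only: a plaintext channel to the local RPC server on the same
// host/port the sample uses. usePlaintext() disables TLS, an assumption
// suitable only for local testing.
fun buildNettyChannel(): ManagedChannel =
  NettyChannelBuilder
    .forAddress("127.0.0.1", 4546)
    .usePlaintext()
    .build()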
@@ -1,11 +1,7 @@
 package gay.pizza.stable.diffusion.sample
 
 import com.google.protobuf.ByteString
-import gay.pizza.stable.diffusion.StableDiffusion
-import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
-import gay.pizza.stable.diffusion.StableDiffusion.Image
-import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
-import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
+import gay.pizza.stable.diffusion.StableDiffusion.*
 import gay.pizza.stable.diffusion.StableDiffusionRpcClient
 import io.grpc.ManagedChannelBuilder
 import kotlin.io.path.Path
@@ -14,7 +10,11 @@ import kotlin.io.path.readBytes
 import kotlin.io.path.writeBytes
 import kotlin.system.exitProcess
 
-fun main() {
+fun main(args: Array<String>) {
+  val chosenModelName = if (args.isNotEmpty()) args[0] else null
+  val chosenPrompt = if (args.size >= 2) args[1] else "cat"
+  val chosenNegativePrompt = if (args.size >= 3) args[2] else "bad, nsfw, low quality"
+
   val channel = ManagedChannelBuilder
     .forAddress("127.0.0.1", 4546)
     .usePlaintext()
@@ -32,7 +32,12 @@ fun main() {
     println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
   }
 
-  val model = modelListResponse.availableModelsList.random()
+  val model = if (chosenModelName == null) {
+    modelListResponse.availableModelsList.random()
+  } else {
+    modelListResponse.availableModelsList.first { it.name == chosenModelName }
+  }
+
   if (!model.isLoaded) {
     println("loading model ${model.name}...")
     client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
@@ -43,20 +48,39 @@ fun main() {
     println("using model ${model.name}...")
   }
 
+  println("tokenizing prompts...")
+
+  val tokenizePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
+    modelName = model.name
+    input = chosenPrompt
+  }.build())
+  val tokenizeNegativePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
+    modelName = model.name
+    input = chosenNegativePrompt
+  }.build())
+
+  println("tokenize prompt='${chosenPrompt}' " +
+    "tokens=[${tokenizePromptResponse.tokensList.joinToString(", ")}] " +
+    "token_ids=[${tokenizePromptResponse.tokenIdsList.joinToString(", ")}]")
+
+  println("tokenize negative_prompt='${chosenNegativePrompt}' " +
+    "tokens=[${tokenizeNegativePromptResponse.tokensList.joinToString(", ")}] " +
+    "token_ids=[${tokenizeNegativePromptResponse.tokenIdsList.joinToString(", ")}]")
+
   println("generating images...")
 
   val startingImagePath = Path("work/start.png")
 
   val request = GenerateImagesRequest.newBuilder().apply {
     modelName = model.name
-    outputImageFormat = StableDiffusion.ImageFormat.png
+    outputImageFormat = ImageFormat.png
     batchSize = 2
     batchCount = 2
-    prompt = "cat"
-    negativePrompt = "bad, low quality, nsfw"
+    prompt = chosenPrompt
+    negativePrompt = chosenNegativePrompt
     if (startingImagePath.exists()) {
       val image = Image.newBuilder().apply {
-        format = StableDiffusion.ImageFormat.png
+        format = ImageFormat.png
         data = ByteString.copyFrom(startingImagePath.readBytes())
       }.build()
 
@@ -65,10 +89,12 @@ fun main() {
   }.build()
   for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
     if (update.hasBatchProgress()) {
-      println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
+      println("batch=${update.currentBatch} " +
+        "progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
+        "overall=${prettyProgressValue(update.overallPercentageComplete)}%")
       for ((index, image) in update.batchProgress.imagesList.withIndex()) {
         val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
-        println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
+        println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
         val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
         path.writeBytes(image.data.toByteArray())
       }
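Note: the imageIndex arithmetic in the streaming loop maps a (batch, index-in-batch) pair to a single 1-based image number across all batches. A hypothetical helper (not part of this commit) that mirrors the same formula, with a worked example:

// Mirrors the numbering used above: batches are 1-based, images within a
// batch are 0-based, and the result is a 1-based index across batches of
// size batchSize.
fun globalImageIndex(currentBatch: Int, batchSize: Int, indexInBatch: Int): Int =
  ((currentBatch - 1) * batchSize) + (indexInBatch + 1)

// Example: with batchSize = 2, the first image of the second batch
// (currentBatch = 2, indexInBatch = 0) is image (2 - 1) * 2 + (0 + 1) = 3.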
@@ -77,13 +103,14 @@ fun main() {
     if (update.hasBatchCompleted()) {
       for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
         val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
-        println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
+        println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
         val path = Path("work/final_${imageIndex}.${image.format.name}")
         path.writeBytes(image.data.toByteArray())
       }
     }
-    println("overall progress ${update.overallPercentageComplete}%")
   }
 
   channel.shutdownNow()
 }
+
+fun prettyProgressValue(value: Float) = String.format("%.2f", value)
@@ -2,7 +2,7 @@ package gay.pizza.stable.diffusion
 
 import io.grpc.Channel
 
-@Suppress("MemberVisibilityCanBePrivate")
+@Suppress("MemberVisibilityCanBePrivate", "unused")
 class StableDiffusionRpcClient(val channel: Channel) {
   val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
     ModelServiceGrpc.newStub(channel)
@@ -35,4 +35,20 @@ class StableDiffusionRpcClient(val channel: Channel) {
   val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
     ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
   }
+
+  val tokenizerService: TokenizerServiceGrpc.TokenizerServiceStub by lazy {
+    TokenizerServiceGrpc.newStub(channel)
+  }
+
+  val tokenizerServiceBlocking: TokenizerServiceGrpc.TokenizerServiceBlockingStub by lazy {
+    TokenizerServiceGrpc.newBlockingStub(channel)
+  }
+
+  val tokenizerServiceFuture: TokenizerServiceGrpc.TokenizerServiceFutureStub by lazy {
+    TokenizerServiceGrpc.newFutureStub(channel)
+  }
+
+  val tokenizerServiceCoroutine: TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub by lazy {
+    TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub(channel)
+  }
 }
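Note: a rough sketch of how the new tokenizer stubs might be used outside the sample, assuming the generated grpc-kotlin coroutine stub exposes the usual suspending tokenize call and that the response carries the tokensList/tokenIdsList fields seen in the sample. The model name and input text below are placeholders, not values from this commit.

import gay.pizza.stable.diffusion.StableDiffusion.TokenizeRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlinx.coroutines.runBlocking

fun main() {
  // Same plaintext endpoint the sample connects to.
  val channel = ManagedChannelBuilder
    .forAddress("127.0.0.1", 4546)
    .usePlaintext()
    .build()
  val client = StableDiffusionRpcClient(channel)

  runBlocking {
    // Suspending call on the new coroutine stub; "some-model" and the
    // input string are placeholder values.
    val response = client.tokenizerServiceCoroutine.tokenize(TokenizeRequest.newBuilder().apply {
      modelName = "some-model"
      input = "a photo of a cat"
    }.build())
    println("tokens=${response.tokensList} token_ids=${response.tokenIdsList}")
  }

  channel.shutdownNow()
}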