Start work on C++ client, and implement streaming of image generation.

This commit is contained in:
2023-04-23 14:22:10 -07:00
parent 1bb629c18f
commit b063d91b1e
11 changed files with 509 additions and 31 deletions

5
Clients/Cpp/.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
/cmake-build-*
/.idea
/.vscode
/src/*.grpc.*
/src/*.pb.*

View File

@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.20)
project(sdrpc)
find_package(Protobuf CONFIG REQUIRED)
find_package(gRPC CONFIG REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_library(sdrpc src/StableDiffusion.proto)
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET sdrpc LANGUAGE cpp)
protobuf_generate(TARGET sdrpc LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")
target_include_directories(sdrpc PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(sdrpc PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
add_executable(sdrpc_sample src/sample.cpp)
target_include_directories(sdrpc_sample PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(sdrpc_sample PRIVATE sdrpc)

View File

@ -0,0 +1 @@
../../../Common/StableDiffusion.proto

View File

@ -0,0 +1,11 @@
#include <StableDiffusion.pb.h>
#include <iostream>
using namespace gay::pizza::stable::diffusion;
int main() {
ModelInfo info;
info.set_name("anything-4.5");
std::cout << info.DebugString() << std::endl;
return 0;
}

View File

@ -30,6 +30,8 @@ dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC")
api("io.grpc:grpc-stub:1.54.1")
api("io.grpc:grpc-protobuf:1.54.1")
api("io.grpc:grpc-kotlin-stub:1.3.0")

View File

@ -63,13 +63,20 @@ fun main() {
startingImage = image
}
}.build()
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) {
if (update.hasBatchProgress()) {
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
}
println("generated ${generateImagesResponse.imagesCount} images:")
for ((index, image) in generateImagesResponse.imagesList.withIndex()) {
println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/image${index}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
if (update.hasBatchCompleted()) {
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/image${imageIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
}
}
println("overall progress ${update.overallPercentageComplete}%")
}
channel.shutdownNow()