Job management and preparation for multi-hosting.

This commit is contained in:
2023-05-08 16:06:07 -07:00
parent a2d9e14f3a
commit ace2c07aa1
30 changed files with 3879 additions and 2307 deletions

View File

@ -6,13 +6,16 @@ find_package(gRPC CONFIG REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_library(sdrpc StableDiffusion.proto)
file(GLOB PROTO_FILES "proto/*.proto")
add_library(sdrpc ${PROTO_FILES})
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET sdrpc LANGUAGE cpp)
protobuf_generate(TARGET sdrpc LANGUAGE cpp IMPORT_DIRS proto)
protobuf_generate(TARGET sdrpc LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
IMPORT_DIRS proto
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")
target_include_directories(sdrpc PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(sdrpc PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)

View File

@ -1 +0,0 @@
../../Common/StableDiffusion.proto

1
Clients/Cpp/proto Symbolic link
View File

@ -0,0 +1 @@
../../Common

View File

@ -1,10 +1,25 @@
#include "StableDiffusion.pb.h"
#include "StableDiffusion.grpc.pb.h"
#include "host.grpc.pb.h"
#include "image_generation.grpc.pb.h"
#include <grpc++/grpc++.h>
using namespace gay::pizza::stable::diffusion;
int CompareModelInfoByLoadedFirst(ModelInfo& left, ModelInfo& right) {
if (left.is_loaded() && right.is_loaded()) {
return 0;
}
if (left.is_loaded()) {
return 1;
}
if (right.is_loaded()) {
return -1;
}
return 0;
}
int main() {
auto channel = grpc::CreateChannel("localhost:4546", grpc::InsecureChannelCredentials());
auto modelService = ModelService::NewStub(channel);
@ -19,5 +34,9 @@ int main() {
for (const auto &item: models) {
std::cout << "Model Name: " << item.name() << std::endl;
}
std::sort(models.begin(), models.end(), CompareModelInfoByLoadedFirst);
auto model = models.begin();
std::cout << "Chosen Model: " << model->name() << std::endl;
return 0;
}

View File

@ -1,9 +1,9 @@
package gay.pizza.stable.diffusion.sample
import com.google.protobuf.ByteString
import gay.pizza.stable.diffusion.StableDiffusion.*
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import gay.pizza.stable.diffusion.*
import io.grpc.ManagedChannelBuilder
import io.grpc.stub.StreamObserver
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.runBlocking
@ -23,6 +23,25 @@ fun main(args: Array<String>) {
.build()
val client = StableDiffusionRpcClient(channel)
val jobs = mutableMapOf<Long, Job>()
client.jobService.streamJobUpdates(StreamJobUpdatesRequest.getDefaultInstance(), object : StreamObserver<JobUpdate> {
override fun onNext(value: JobUpdate) {
jobs[value.job.id] = value.job
jobs.values.map {
"job=${it.id} status=${it.state.name} completion=${it.overallPercentageComplete}"
}.forEach(::println)
}
override fun onError(throwable: Throwable) {
throwable.printStackTrace()
exitProcess(1)
}
override fun onCompleted() {}
})
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
if (modelListResponse.availableModelsList.isEmpty()) {
println("no available models")

View File

@ -51,4 +51,20 @@ class StableDiffusionRpcClient(val channel: Channel) {
val tokenizerServiceCoroutine: TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub by lazy {
TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub(channel)
}
val jobService: JobServiceGrpc.JobServiceStub by lazy {
JobServiceGrpc.newStub(channel)
}
val jobServiceBlocking: JobServiceGrpc.JobServiceBlockingStub by lazy {
JobServiceGrpc.newBlockingStub(channel)
}
val jobServiceFuture: JobServiceGrpc.JobServiceFutureStub by lazy {
JobServiceGrpc.newFutureStub(channel)
}
val jobServiceCoroutine: JobServiceGrpcKt.JobServiceCoroutineStub by lazy {
JobServiceGrpcKt.JobServiceCoroutineStub(channel)
}
}