mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-05 06:21:31 +00:00
Split out worker related things to a separate service definition.
This commit is contained in:
@ -1,25 +1,11 @@
|
||||
#include "host.grpc.pb.h"
|
||||
#include "model.grpc.pb.h"
|
||||
#include "image_generation.grpc.pb.h"
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
|
||||
using namespace gay::pizza::stable::diffusion;
|
||||
|
||||
int CompareModelInfoByLoadedFirst(ModelInfo& left, ModelInfo& right) {
|
||||
if (left.is_loaded() && right.is_loaded()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (left.is_loaded()) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (right.is_loaded()) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
auto channel = grpc::CreateChannel("localhost:4546", grpc::InsecureChannelCredentials());
|
||||
auto modelService = ModelService::NewStub(channel);
|
||||
@ -34,9 +20,5 @@ int main() {
|
||||
for (const auto &item: models) {
|
||||
std::cout << "Model Name: " << item.name() << std::endl;
|
||||
}
|
||||
|
||||
std::sort(models.begin(), models.end(), CompareModelInfoByLoadedFirst);
|
||||
auto model = models.begin();
|
||||
std::cout << "Chosen Model: " << model->name() << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
@ -49,8 +49,7 @@ fun main(args: Array<String>) {
|
||||
}
|
||||
println("available models:")
|
||||
for (model in modelListResponse.availableModelsList) {
|
||||
val maybeLoadedComputeUnits = if (model.isLoaded) " loaded_compute_units=${model.loadedComputeUnits.name}" else ""
|
||||
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
|
||||
println(" model ${model.name} attention=${model.attention}")
|
||||
}
|
||||
|
||||
val model = if (chosenModelName == null) {
|
||||
@ -59,15 +58,11 @@ fun main(args: Array<String>) {
|
||||
modelListResponse.availableModelsList.first { it.name == chosenModelName }
|
||||
}
|
||||
|
||||
if (!model.isLoaded) {
|
||||
println("loading model ${model.name}...")
|
||||
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
computeUnits = model.supportedComputeUnitsList.first()
|
||||
}.build())
|
||||
} else {
|
||||
println("using model ${model.name}...")
|
||||
}
|
||||
println("loading model ${model.name}...")
|
||||
client.hostModelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
computeUnits = model.supportedComputeUnitsList.first()
|
||||
}.build())
|
||||
|
||||
println("tokenizing prompts...")
|
||||
|
||||
|
@ -20,6 +20,22 @@ class StableDiffusionRpcClient(val channel: Channel) {
|
||||
ModelServiceGrpcKt.ModelServiceCoroutineStub(channel)
|
||||
}
|
||||
|
||||
val hostModelService: HostModelServiceGrpc.HostModelServiceStub by lazy {
|
||||
HostModelServiceGrpc.newStub(channel)
|
||||
}
|
||||
|
||||
val hostModelServiceBlocking: HostModelServiceGrpc.HostModelServiceBlockingStub by lazy {
|
||||
HostModelServiceGrpc.newBlockingStub(channel)
|
||||
}
|
||||
|
||||
val hostModelServiceFuture: HostModelServiceGrpc.HostModelServiceFutureStub by lazy {
|
||||
HostModelServiceGrpc.newFutureStub(channel)
|
||||
}
|
||||
|
||||
val hostModelServiceCoroutine: HostModelServiceGrpcKt.HostModelServiceCoroutineStub by lazy {
|
||||
HostModelServiceGrpcKt.HostModelServiceCoroutineStub(channel)
|
||||
}
|
||||
|
||||
val imageGenerationService: ImageGenerationServiceGrpc.ImageGenerationServiceStub by lazy {
|
||||
ImageGenerationServiceGrpc.newStub(channel)
|
||||
}
|
||||
|
Reference in New Issue
Block a user