Split out worker related things to a separate service definition.

This commit is contained in:
2023-05-08 22:12:24 -07:00
parent ace2c07aa1
commit 2e5a37ea4b
28 changed files with 1271 additions and 359 deletions

View File

@@ -1,25 +1,11 @@
#include "host.grpc.pb.h"
#include "model.grpc.pb.h"
#include "image_generation.grpc.pb.h"
#include <grpc++/grpc++.h>
using namespace gay::pizza::stable::diffusion;
// Comparator for std::sort that orders models with loaded ones first.
//
// NOTE(review): std::sort requires a bool-returning "left strictly before
// right" predicate forming a strict weak ordering. The previous three-way
// (-1/0/1) contract was broken here: the -1 case converted to `true` when
// the result was used as a bool, reporting "left before right" even when
// only the RIGHT model was loaded, which violates strict weak ordering and
// makes std::sort undefined behavior. The int signature is kept for
// compatibility, but the value is now truthy exactly when `left` should
// precede `right` (left loaded, right not).
int CompareModelInfoByLoadedFirst(ModelInfo& left, ModelInfo& right) {
    return (left.is_loaded() && !right.is_loaded()) ? 1 : 0;
}
int main() {
auto channel = grpc::CreateChannel("localhost:4546", grpc::InsecureChannelCredentials());
auto modelService = ModelService::NewStub(channel);
@@ -34,9 +20,5 @@ int main() {
for (const auto &item: models) {
std::cout << "Model Name: " << item.name() << std::endl;
}
std::sort(models.begin(), models.end(), CompareModelInfoByLoadedFirst);
auto model = models.begin();
std::cout << "Chosen Model: " << model->name() << std::endl;
return 0;
}

View File

@@ -49,8 +49,7 @@ fun main(args: Array<String>) {
}
println("available models:")
for (model in modelListResponse.availableModelsList) {
val maybeLoadedComputeUnits = if (model.isLoaded) " loaded_compute_units=${model.loadedComputeUnits.name}" else ""
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
println(" model ${model.name} attention=${model.attention}")
}
val model = if (chosenModelName == null) {
@@ -59,15 +58,11 @@ fun main(args: Array<String>) {
modelListResponse.availableModelsList.first { it.name == chosenModelName }
}
if (!model.isLoaded) {
println("loading model ${model.name}...")
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
modelName = model.name
computeUnits = model.supportedComputeUnitsList.first()
}.build())
} else {
println("using model ${model.name}...")
}
println("loading model ${model.name}...")
client.hostModelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
modelName = model.name
computeUnits = model.supportedComputeUnitsList.first()
}.build())
println("tokenizing prompts...")

View File

@@ -20,6 +20,22 @@ class StableDiffusionRpcClient(val channel: Channel) {
ModelServiceGrpcKt.ModelServiceCoroutineStub(channel)
}
// Async (StreamObserver-based) stub for the host model service. Created on
// first access and cached for the lifetime of the client; `by lazy` is
// thread-safe (SYNCHRONIZED) by default.
val hostModelService: HostModelServiceGrpc.HostModelServiceStub by lazy {
HostModelServiceGrpc.newStub(channel)
}
// Blocking stub for the host model service — each call waits synchronously
// for the server's response.
val hostModelServiceBlocking: HostModelServiceGrpc.HostModelServiceBlockingStub by lazy {
HostModelServiceGrpc.newBlockingStub(channel)
}
// ListenableFuture-based stub for the host model service, for callers that
// compose asynchronous results as futures.
val hostModelServiceFuture: HostModelServiceGrpc.HostModelServiceFutureStub by lazy {
HostModelServiceGrpc.newFutureStub(channel)
}
// Kotlin coroutine stub (suspend-function API) for the host model service.
val hostModelServiceCoroutine: HostModelServiceGrpcKt.HostModelServiceCoroutineStub by lazy {
HostModelServiceGrpcKt.HostModelServiceCoroutineStub(channel)
}
// Async (StreamObserver-based) stub for the image generation service.
val imageGenerationService: ImageGenerationServiceGrpc.ImageGenerationServiceStub by lazy {
ImageGenerationServiceGrpc.newStub(channel)
}