mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 13:31:32 +00:00
Start work on C++ client, and implement streaming of image generation.
This commit is contained in:
5
Clients/Cpp/.gitignore
vendored
Normal file
5
Clients/Cpp/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
/cmake-build-*
|
||||
/.idea
|
||||
/.vscode
|
||||
/src/*.grpc.*
|
||||
/src/*.pb.*
|
22
Clients/Cpp/CMakeLists.txt
Normal file
22
Clients/Cpp/CMakeLists.txt
Normal file
@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(sdrpc)
|
||||
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
find_package(gRPC CONFIG REQUIRED)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
add_library(sdrpc src/StableDiffusion.proto)
|
||||
|
||||
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
|
||||
|
||||
protobuf_generate(TARGET sdrpc LANGUAGE cpp)
|
||||
protobuf_generate(TARGET sdrpc LANGUAGE grpc
|
||||
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
|
||||
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")
|
||||
target_include_directories(sdrpc PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||
target_link_libraries(sdrpc PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
|
||||
|
||||
add_executable(sdrpc_sample src/sample.cpp)
|
||||
target_include_directories(sdrpc_sample PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||
target_link_libraries(sdrpc_sample PRIVATE sdrpc)
|
1
Clients/Cpp/src/StableDiffusion.proto
Symbolic link
1
Clients/Cpp/src/StableDiffusion.proto
Symbolic link
@ -0,0 +1 @@
|
||||
../../../Common/StableDiffusion.proto
|
11
Clients/Cpp/src/sample.cpp
Normal file
11
Clients/Cpp/src/sample.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
#include <StableDiffusion.pb.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace gay::pizza::stable::diffusion;
|
||||
|
||||
int main() {
|
||||
ModelInfo info;
|
||||
info.set_name("anything-4.5");
|
||||
std::cout << info.DebugString() << std::endl;
|
||||
return 0;
|
||||
}
|
@ -30,6 +30,8 @@ dependencies {
|
||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC")
|
||||
|
||||
api("io.grpc:grpc-stub:1.54.1")
|
||||
api("io.grpc:grpc-protobuf:1.54.1")
|
||||
api("io.grpc:grpc-kotlin-stub:1.3.0")
|
||||
|
@ -63,13 +63,20 @@ fun main() {
|
||||
startingImage = image
|
||||
}
|
||||
}.build()
|
||||
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
|
||||
for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) {
|
||||
if (update.hasBatchProgress()) {
|
||||
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
|
||||
}
|
||||
|
||||
println("generated ${generateImagesResponse.imagesCount} images:")
|
||||
for ((index, image) in generateImagesResponse.imagesList.withIndex()) {
|
||||
println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/image${index}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
if (update.hasBatchCompleted()) {
|
||||
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/image${imageIndex}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
}
|
||||
}
|
||||
println("overall progress ${update.overallPercentageComplete}%")
|
||||
}
|
||||
|
||||
channel.shutdownNow()
|
||||
|
@ -270,6 +270,29 @@ message GenerateImagesResponse {
|
||||
repeated uint32 seeds = 2;
|
||||
}
|
||||
|
||||
message GenerateImagesBatchProgressUpdate {
|
||||
float percentage_complete = 1;
|
||||
}
|
||||
|
||||
message GenerateImagesBatchCompletedUpdate {
|
||||
repeated Image images = 1;
|
||||
uint32 seed = 2;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a continuous update from an image generation stream.
|
||||
*/
|
||||
message GenerateImagesStreamUpdate {
|
||||
uint32 current_batch = 1;
|
||||
|
||||
oneof update {
|
||||
GenerateImagesBatchProgressUpdate batch_progress = 2;
|
||||
GenerateImagesBatchCompletedUpdate batch_completed = 3;
|
||||
}
|
||||
|
||||
float overall_percentage_complete = 4;
|
||||
}
|
||||
|
||||
/**
|
||||
* The image generation service, for generating images from loaded models.
|
||||
*/
|
||||
@ -278,4 +301,6 @@ service ImageGenerationService {
|
||||
* Generates images using a loaded model.
|
||||
*/
|
||||
rpc GenerateImages(GenerateImagesRequest) returns (GenerateImagesResponse);
|
||||
|
||||
rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate);
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
import CoreML
|
||||
import Foundation
|
||||
import GRPC
|
||||
import StableDiffusion
|
||||
import StableDiffusionProtos
|
||||
|
||||
@ -44,7 +45,82 @@ public actor ModelState {
|
||||
}
|
||||
|
||||
let baseSeed: UInt32 = request.seed
|
||||
var pipelineConfig = try toPipelineConfig(request)
|
||||
|
||||
var response = SdGenerateImagesResponse()
|
||||
for _ in 0 ..< request.batchCount {
|
||||
var seed = baseSeed
|
||||
if seed == 0 {
|
||||
seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
}
|
||||
pipelineConfig.seed = seed
|
||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||
try response.images.append(contentsOf: cgImagesToImages(request: request, images))
|
||||
response.seeds.append(seed)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
|
||||
guard let pipeline else {
|
||||
throw SdCoreError.modelNotLoaded
|
||||
}
|
||||
|
||||
let baseSeed: UInt32 = request.seed
|
||||
var pipelineConfig = try toPipelineConfig(request)
|
||||
|
||||
for batch in 1 ... request.batchCount {
|
||||
@Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
|
||||
let eachSegment = 100.0 / Float(request.batchCount)
|
||||
let alreadyCompletedSegments = (Float(batch) - 1) * eachSegment
|
||||
let percentageToAdd = eachSegment * (batchPercentage / 100.0)
|
||||
return alreadyCompletedSegments + percentageToAdd
|
||||
}
|
||||
|
||||
var seed = baseSeed
|
||||
if seed == 0 {
|
||||
seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
}
|
||||
pipelineConfig.seed = seed
|
||||
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
|
||||
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
|
||||
Task {
|
||||
do {
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchProgress = .with { update in
|
||||
update.percentageComplete = percentage
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(percentage)
|
||||
})
|
||||
} catch {
|
||||
fatalError(error.localizedDescription)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
let images = try cgImagesToImages(request: request, cgImages)
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchCompleted = .with { update in
|
||||
update.images = images
|
||||
update.seed = seed
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(100.0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {
|
||||
var images: [SdImage] = []
|
||||
for cgImage in cgImages {
|
||||
guard let cgImage else { continue }
|
||||
try images.append(cgImage.toSdImage(format: request.outputImageFormat))
|
||||
}
|
||||
return images
|
||||
}
|
||||
|
||||
private func toPipelineConfig(_ request: SdGenerateImagesRequest) throws -> StableDiffusionPipeline.Configuration {
|
||||
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
||||
pipelineConfig.negativePrompt = request.negativePrompt
|
||||
pipelineConfig.imageCount = Int(request.batchSize)
|
||||
@ -72,22 +148,6 @@ public actor ModelState {
|
||||
case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
|
||||
default: pipelineConfig.schedulerType = .pndmScheduler
|
||||
}
|
||||
|
||||
var response = SdGenerateImagesResponse()
|
||||
for _ in 0 ..< request.batchCount {
|
||||
var seed = baseSeed
|
||||
if seed == 0 {
|
||||
seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
}
|
||||
pipelineConfig.seed = seed
|
||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||
|
||||
for cgImage in images {
|
||||
guard let cgImage else { continue }
|
||||
try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
|
||||
}
|
||||
response.seeds.append(seed)
|
||||
}
|
||||
return response
|
||||
return pipelineConfig
|
||||
}
|
||||
}
|
||||
|
@ -295,6 +295,12 @@ public protocol SdImageGenerationServiceClientProtocol: GRPCClient {
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> UnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
|
||||
|
||||
func generateImagesStreaming(
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions?,
|
||||
handler: @escaping (SdGenerateImagesStreamUpdate) -> Void
|
||||
) -> ServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>
|
||||
}
|
||||
|
||||
extension SdImageGenerationServiceClientProtocol {
|
||||
@ -320,6 +326,27 @@ extension SdImageGenerationServiceClientProtocol {
|
||||
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
|
||||
/// Server streaming call to GenerateImagesStreaming
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - request: Request to send to GenerateImagesStreaming.
|
||||
/// - callOptions: Call options.
|
||||
/// - handler: A closure called when each response is received from the server.
|
||||
/// - Returns: A `ServerStreamingCall` with futures for the metadata and status.
|
||||
public func generateImagesStreaming(
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions? = nil,
|
||||
handler: @escaping (SdGenerateImagesStreamUpdate) -> Void
|
||||
) -> ServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate> {
|
||||
return self.makeServerStreamingCall(
|
||||
path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [],
|
||||
handler: handler
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
@ -393,6 +420,11 @@ public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient {
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> GRPCAsyncUnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
|
||||
|
||||
func makeGenerateImagesStreamingCall(
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> GRPCAsyncServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@ -416,6 +448,18 @@ extension SdImageGenerationServiceAsyncClientProtocol {
|
||||
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
|
||||
public func makeGenerateImagesStreamingCall(
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) -> GRPCAsyncServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate> {
|
||||
return self.makeAsyncServerStreamingCall(
|
||||
path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@ -431,6 +475,18 @@ extension SdImageGenerationServiceAsyncClientProtocol {
|
||||
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
|
||||
public func generateImagesStreaming(
|
||||
_ request: SdGenerateImagesRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) -> GRPCAsyncResponseStream<SdGenerateImagesStreamUpdate> {
|
||||
return self.performAsyncServerStreamingCall(
|
||||
path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@ -456,6 +512,9 @@ public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSe
|
||||
|
||||
/// - Returns: Interceptors to use when invoking 'generateImages'.
|
||||
func makeGenerateImagesInterceptors() -> [ClientInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
|
||||
|
||||
/// - Returns: Interceptors to use when invoking 'generateImagesStreaming'.
|
||||
func makeGenerateImagesStreamingInterceptors() -> [ClientInterceptor<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>]
|
||||
}
|
||||
|
||||
public enum SdImageGenerationServiceClientMetadata {
|
||||
@ -464,6 +523,7 @@ public enum SdImageGenerationServiceClientMetadata {
|
||||
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
|
||||
methods: [
|
||||
SdImageGenerationServiceClientMetadata.Methods.generateImages,
|
||||
SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming,
|
||||
]
|
||||
)
|
||||
|
||||
@ -473,6 +533,12 @@ public enum SdImageGenerationServiceClientMetadata {
|
||||
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages",
|
||||
type: GRPCCallType.unary
|
||||
)
|
||||
|
||||
public static let generateImagesStreaming = GRPCMethodDescriptor(
|
||||
name: "GenerateImagesStreaming",
|
||||
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImagesStreaming",
|
||||
type: GRPCCallType.serverStreaming
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -646,6 +712,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
|
||||
///*
|
||||
/// Generates images using a loaded model.
|
||||
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
|
||||
|
||||
func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext<SdGenerateImagesStreamUpdate>) -> EventLoopFuture<GRPCStatus>
|
||||
}
|
||||
|
||||
extension SdImageGenerationServiceProvider {
|
||||
@ -669,6 +737,15 @@ extension SdImageGenerationServiceProvider {
|
||||
userFunction: self.generateImages(request:context:)
|
||||
)
|
||||
|
||||
case "GenerateImagesStreaming":
|
||||
return ServerStreamingServerHandler(
|
||||
context: context,
|
||||
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
|
||||
responseSerializer: ProtobufSerializer<SdGenerateImagesStreamUpdate>(),
|
||||
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [],
|
||||
userFunction: self.generateImagesStreaming(request:context:)
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@ -692,6 +769,12 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
|
||||
request: SdGenerateImagesRequest,
|
||||
context: GRPCAsyncServerCallContext
|
||||
) async throws -> SdGenerateImagesResponse
|
||||
|
||||
@Sendable func generateImagesStreaming(
|
||||
request: SdGenerateImagesRequest,
|
||||
responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,
|
||||
context: GRPCAsyncServerCallContext
|
||||
) async throws
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@ -722,6 +805,15 @@ extension SdImageGenerationServiceAsyncProvider {
|
||||
wrapping: self.generateImages(request:context:)
|
||||
)
|
||||
|
||||
case "GenerateImagesStreaming":
|
||||
return GRPCAsyncServerHandler(
|
||||
context: context,
|
||||
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
|
||||
responseSerializer: ProtobufSerializer<SdGenerateImagesStreamUpdate>(),
|
||||
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [],
|
||||
wrapping: self.generateImagesStreaming(request:responseStream:context:)
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@ -735,6 +827,10 @@ public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol {
|
||||
/// - Returns: Interceptors to use when handling 'generateImages'.
|
||||
/// Defaults to calling `self.makeInterceptors()`.
|
||||
func makeGenerateImagesInterceptors() -> [ServerInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
|
||||
|
||||
/// - Returns: Interceptors to use when handling 'generateImagesStreaming'.
|
||||
/// Defaults to calling `self.makeInterceptors()`.
|
||||
func makeGenerateImagesStreamingInterceptors() -> [ServerInterceptor<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>]
|
||||
}
|
||||
|
||||
public enum SdImageGenerationServiceServerMetadata {
|
||||
@ -743,6 +839,7 @@ public enum SdImageGenerationServiceServerMetadata {
|
||||
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
|
||||
methods: [
|
||||
SdImageGenerationServiceServerMetadata.Methods.generateImages,
|
||||
SdImageGenerationServiceServerMetadata.Methods.generateImagesStreaming,
|
||||
]
|
||||
)
|
||||
|
||||
@ -752,5 +849,11 @@ public enum SdImageGenerationServiceServerMetadata {
|
||||
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages",
|
||||
type: GRPCCallType.unary
|
||||
)
|
||||
|
||||
public static let generateImagesStreaming = GRPCMethodDescriptor(
|
||||
name: "GenerateImagesStreaming",
|
||||
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImagesStreaming",
|
||||
type: GRPCCallType.serverStreaming
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -442,6 +442,90 @@ public struct SdGenerateImagesResponse {
|
||||
public init() {}
|
||||
}
|
||||
|
||||
public struct SdGenerateImagesBatchProgressUpdate {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
public var percentageComplete: Float = 0
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
public struct SdGenerateImagesBatchCompletedUpdate {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
public var images: [SdImage] = []
|
||||
|
||||
public var seed: UInt32 = 0
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a continuous update from an image generation stream.
|
||||
public struct SdGenerateImagesStreamUpdate {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
public var currentBatch: UInt32 = 0
|
||||
|
||||
public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil
|
||||
|
||||
public var batchProgress: SdGenerateImagesBatchProgressUpdate {
|
||||
get {
|
||||
if case .batchProgress(let v)? = update {return v}
|
||||
return SdGenerateImagesBatchProgressUpdate()
|
||||
}
|
||||
set {update = .batchProgress(newValue)}
|
||||
}
|
||||
|
||||
public var batchCompleted: SdGenerateImagesBatchCompletedUpdate {
|
||||
get {
|
||||
if case .batchCompleted(let v)? = update {return v}
|
||||
return SdGenerateImagesBatchCompletedUpdate()
|
||||
}
|
||||
set {update = .batchCompleted(newValue)}
|
||||
}
|
||||
|
||||
public var overallPercentageComplete: Float = 0
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public enum OneOf_Update: Equatable {
|
||||
case batchProgress(SdGenerateImagesBatchProgressUpdate)
|
||||
case batchCompleted(SdGenerateImagesBatchCompletedUpdate)
|
||||
|
||||
#if !swift(>=4.1)
|
||||
public static func ==(lhs: SdGenerateImagesStreamUpdate.OneOf_Update, rhs: SdGenerateImagesStreamUpdate.OneOf_Update) -> Bool {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch (lhs, rhs) {
|
||||
case (.batchProgress, .batchProgress): return {
|
||||
guard case .batchProgress(let l) = lhs, case .batchProgress(let r) = rhs else { preconditionFailure() }
|
||||
return l == r
|
||||
}()
|
||||
case (.batchCompleted, .batchCompleted): return {
|
||||
guard case .batchCompleted(let l) = lhs, case .batchCompleted(let r) = rhs else { preconditionFailure() }
|
||||
return l == r
|
||||
}()
|
||||
default: return false
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
#if swift(>=5.5) && canImport(_Concurrency)
|
||||
extension SdModelAttention: @unchecked Sendable {}
|
||||
extension SdScheduler: @unchecked Sendable {}
|
||||
@ -455,6 +539,10 @@ extension SdLoadModelRequest: @unchecked Sendable {}
|
||||
extension SdLoadModelResponse: @unchecked Sendable {}
|
||||
extension SdGenerateImagesRequest: @unchecked Sendable {}
|
||||
extension SdGenerateImagesResponse: @unchecked Sendable {}
|
||||
extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
|
||||
#endif // swift(>=5.5) && canImport(_Concurrency)
|
||||
|
||||
// MARK: - Code below here is support for the SwiftProtobuf runtime.
|
||||
@ -837,3 +925,155 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .standard(proto: "percentage_complete"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
if self.percentageComplete != 0 {
|
||||
try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool {
|
||||
if lhs.percentageComplete != rhs.percentageComplete {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdGenerateImagesBatchCompletedUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchCompletedUpdate"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .same(proto: "images"),
|
||||
2: .same(proto: "seed"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
|
||||
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
if !self.images.isEmpty {
|
||||
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1)
|
||||
}
|
||||
if self.seed != 0 {
|
||||
try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 2)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdGenerateImagesBatchCompletedUpdate, rhs: SdGenerateImagesBatchCompletedUpdate) -> Bool {
|
||||
if lhs.images != rhs.images {return false}
|
||||
if lhs.seed != rhs.seed {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesStreamUpdate"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .standard(proto: "current_batch"),
|
||||
2: .standard(proto: "batch_progress"),
|
||||
3: .standard(proto: "batch_completed"),
|
||||
4: .standard(proto: "overall_percentage_complete"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.currentBatch) }()
|
||||
case 2: try {
|
||||
var v: SdGenerateImagesBatchProgressUpdate?
|
||||
var hadOneofValue = false
|
||||
if let current = self.update {
|
||||
hadOneofValue = true
|
||||
if case .batchProgress(let m) = current {v = m}
|
||||
}
|
||||
try decoder.decodeSingularMessageField(value: &v)
|
||||
if let v = v {
|
||||
if hadOneofValue {try decoder.handleConflictingOneOf()}
|
||||
self.update = .batchProgress(v)
|
||||
}
|
||||
}()
|
||||
case 3: try {
|
||||
var v: SdGenerateImagesBatchCompletedUpdate?
|
||||
var hadOneofValue = false
|
||||
if let current = self.update {
|
||||
hadOneofValue = true
|
||||
if case .batchCompleted(let m) = current {v = m}
|
||||
}
|
||||
try decoder.decodeSingularMessageField(value: &v)
|
||||
if let v = v {
|
||||
if hadOneofValue {try decoder.handleConflictingOneOf()}
|
||||
self.update = .batchCompleted(v)
|
||||
}
|
||||
}()
|
||||
case 4: try { try decoder.decodeSingularFloatField(value: &self.overallPercentageComplete) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every if/case branch local when no optimizations
|
||||
// are enabled. https://github.com/apple/swift-protobuf/issues/1034 and
|
||||
// https://github.com/apple/swift-protobuf/issues/1182
|
||||
if self.currentBatch != 0 {
|
||||
try visitor.visitSingularUInt32Field(value: self.currentBatch, fieldNumber: 1)
|
||||
}
|
||||
switch self.update {
|
||||
case .batchProgress?: try {
|
||||
guard case .batchProgress(let v)? = self.update else { preconditionFailure() }
|
||||
try visitor.visitSingularMessageField(value: v, fieldNumber: 2)
|
||||
}()
|
||||
case .batchCompleted?: try {
|
||||
guard case .batchCompleted(let v)? = self.update else { preconditionFailure() }
|
||||
try visitor.visitSingularMessageField(value: v, fieldNumber: 3)
|
||||
}()
|
||||
case nil: break
|
||||
}
|
||||
if self.overallPercentageComplete != 0 {
|
||||
try visitor.visitSingularFloatField(value: self.overallPercentageComplete, fieldNumber: 4)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdGenerateImagesStreamUpdate, rhs: SdGenerateImagesStreamUpdate) -> Bool {
|
||||
if lhs.currentBatch != rhs.currentBatch {return false}
|
||||
if lhs.update != rhs.update {return false}
|
||||
if lhs.overallPercentageComplete != rhs.overallPercentageComplete {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -11,14 +11,16 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
|
||||
}
|
||||
|
||||
func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
|
||||
do {
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdCoreError.modelNotFound
|
||||
}
|
||||
return try await state.generate(request)
|
||||
} catch {
|
||||
print(error)
|
||||
throw error
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdCoreError.modelNotFound
|
||||
}
|
||||
return try await state.generate(request)
|
||||
}
|
||||
|
||||
func generateImagesStreaming(request: SdGenerateImagesRequest, responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>, context _: GRPCAsyncServerCallContext) async throws {
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdCoreError.modelNotFound
|
||||
}
|
||||
try await state.generateStreaming(request, stream: responseStream)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user