Start work on C++ client, and implement streaming of image generation.

This commit is contained in:
2023-04-23 14:22:10 -07:00
parent 1bb629c18f
commit b063d91b1e
11 changed files with 509 additions and 31 deletions

5
Clients/Cpp/.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
/cmake-build-*
/.idea
/.vscode
/src/*.grpc.*
/src/*.pb.*

View File

@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.20)
project(sdrpc)
find_package(Protobuf CONFIG REQUIRED)
find_package(gRPC CONFIG REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_library(sdrpc src/StableDiffusion.proto)
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET sdrpc LANGUAGE cpp)
protobuf_generate(TARGET sdrpc LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")
target_include_directories(sdrpc PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(sdrpc PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
add_executable(sdrpc_sample src/sample.cpp)
target_include_directories(sdrpc_sample PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(sdrpc_sample PRIVATE sdrpc)

View File

@ -0,0 +1 @@
../../../Common/StableDiffusion.proto

View File

@ -0,0 +1,11 @@
#include <StableDiffusion.pb.h>
#include <iostream>
using namespace gay::pizza::stable::diffusion;
int main() {
ModelInfo info;
info.set_name("anything-4.5");
std::cout << info.DebugString() << std::endl;
return 0;
}

View File

@ -30,6 +30,8 @@ dependencies {
implementation("org.jetbrains.kotlin:kotlin-bom")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC")
api("io.grpc:grpc-stub:1.54.1")
api("io.grpc:grpc-protobuf:1.54.1")
api("io.grpc:grpc-kotlin-stub:1.3.0")

View File

@ -63,13 +63,20 @@ fun main() {
startingImage = image
}
}.build()
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) {
if (update.hasBatchProgress()) {
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
}
println("generated ${generateImagesResponse.imagesCount} images:")
for ((index, image) in generateImagesResponse.imagesList.withIndex()) {
println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/image${index}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
if (update.hasBatchCompleted()) {
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/image${imageIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
}
}
println("overall progress ${update.overallPercentageComplete}%")
}
channel.shutdownNow()

View File

@ -270,6 +270,29 @@ message GenerateImagesResponse {
repeated uint32 seeds = 2;
}
message GenerateImagesBatchProgressUpdate {
float percentage_complete = 1;
}
message GenerateImagesBatchCompletedUpdate {
repeated Image images = 1;
uint32 seed = 2;
}
/**
* Represents a continuous update from an image generation stream.
*/
message GenerateImagesStreamUpdate {
uint32 current_batch = 1;
oneof update {
GenerateImagesBatchProgressUpdate batch_progress = 2;
GenerateImagesBatchCompletedUpdate batch_completed = 3;
}
float overall_percentage_complete = 4;
}
/**
* The image generation service, for generating images from loaded models.
*/
@ -278,4 +301,6 @@ service ImageGenerationService {
* Generates images using a loaded model.
*/
rpc GenerateImages(GenerateImagesRequest) returns (GenerateImagesResponse);
rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate);
}

View File

@ -1,5 +1,6 @@
import CoreML
import Foundation
import GRPC
import StableDiffusion
import StableDiffusionProtos
@ -44,7 +45,82 @@ public actor ModelState {
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
try response.images.append(contentsOf: cgImagesToImages(request: request, images))
response.seeds.append(seed)
}
return response
}
public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
guard let pipeline else {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
for batch in 1 ... request.batchCount {
@Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
let eachSegment = 100.0 / Float(request.batchCount)
let alreadyCompletedSegments = (Float(batch) - 1) * eachSegment
let percentageToAdd = eachSegment * (batchPercentage / 100.0)
return alreadyCompletedSegments + percentageToAdd
}
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
Task {
do {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
}
item.overallPercentageComplete = currentOverallPercentage(percentage)
})
} catch {
fatalError(error.localizedDescription)
}
}
return true
})
let images = try cgImagesToImages(request: request, cgImages)
try await stream.send(.with { item in
item.currentBatch = batch
item.batchCompleted = .with { update in
update.images = images
update.seed = seed
}
item.overallPercentageComplete = currentOverallPercentage(100.0)
})
}
}
private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {
var images: [SdImage] = []
for cgImage in cgImages {
guard let cgImage else { continue }
try images.append(cgImage.toSdImage(format: request.outputImageFormat))
}
return images
}
private func toPipelineConfig(_ request: SdGenerateImagesRequest) throws -> StableDiffusionPipeline.Configuration {
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
pipelineConfig.negativePrompt = request.negativePrompt
pipelineConfig.imageCount = Int(request.batchSize)
@ -72,22 +148,6 @@ public actor ModelState {
case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
default: pipelineConfig.schedulerType = .pndmScheduler
}
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
for cgImage in images {
guard let cgImage else { continue }
try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
}
response.seeds.append(seed)
}
return response
return pipelineConfig
}
}

View File

@ -295,6 +295,12 @@ public protocol SdImageGenerationServiceClientProtocol: GRPCClient {
_ request: SdGenerateImagesRequest,
callOptions: CallOptions?
) -> UnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
func generateImagesStreaming(
_ request: SdGenerateImagesRequest,
callOptions: CallOptions?,
handler: @escaping (SdGenerateImagesStreamUpdate) -> Void
) -> ServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>
}
extension SdImageGenerationServiceClientProtocol {
@ -320,6 +326,27 @@ extension SdImageGenerationServiceClientProtocol {
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
)
}
/// Server streaming call to GenerateImagesStreaming
///
/// - Parameters:
/// - request: Request to send to GenerateImagesStreaming.
/// - callOptions: Call options.
/// - handler: A closure called when each response is received from the server.
/// - Returns: A `ServerStreamingCall` with futures for the metadata and status.
public func generateImagesStreaming(
_ request: SdGenerateImagesRequest,
callOptions: CallOptions? = nil,
handler: @escaping (SdGenerateImagesStreamUpdate) -> Void
) -> ServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate> {
return self.makeServerStreamingCall(
path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [],
handler: handler
)
}
}
#if compiler(>=5.6)
@ -393,6 +420,11 @@ public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient {
_ request: SdGenerateImagesRequest,
callOptions: CallOptions?
) -> GRPCAsyncUnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
func makeGenerateImagesStreamingCall(
_ request: SdGenerateImagesRequest,
callOptions: CallOptions?
) -> GRPCAsyncServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@ -416,6 +448,18 @@ extension SdImageGenerationServiceAsyncClientProtocol {
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
)
}
public func makeGenerateImagesStreamingCall(
_ request: SdGenerateImagesRequest,
callOptions: CallOptions? = nil
) -> GRPCAsyncServerStreamingCall<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate> {
return self.makeAsyncServerStreamingCall(
path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? []
)
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@ -431,6 +475,18 @@ extension SdImageGenerationServiceAsyncClientProtocol {
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
)
}
public func generateImagesStreaming(
_ request: SdGenerateImagesRequest,
callOptions: CallOptions? = nil
) -> GRPCAsyncResponseStream<SdGenerateImagesStreamUpdate> {
return self.performAsyncServerStreamingCall(
path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? []
)
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@ -456,6 +512,9 @@ public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSe
/// - Returns: Interceptors to use when invoking 'generateImages'.
func makeGenerateImagesInterceptors() -> [ClientInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
/// - Returns: Interceptors to use when invoking 'generateImagesStreaming'.
func makeGenerateImagesStreamingInterceptors() -> [ClientInterceptor<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>]
}
public enum SdImageGenerationServiceClientMetadata {
@ -464,6 +523,7 @@ public enum SdImageGenerationServiceClientMetadata {
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
methods: [
SdImageGenerationServiceClientMetadata.Methods.generateImages,
SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming,
]
)
@ -473,6 +533,12 @@ public enum SdImageGenerationServiceClientMetadata {
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages",
type: GRPCCallType.unary
)
public static let generateImagesStreaming = GRPCMethodDescriptor(
name: "GenerateImagesStreaming",
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImagesStreaming",
type: GRPCCallType.serverStreaming
)
}
}
@ -646,6 +712,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
///*
/// Generates images using a loaded model.
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext<SdGenerateImagesStreamUpdate>) -> EventLoopFuture<GRPCStatus>
}
extension SdImageGenerationServiceProvider {
@ -669,6 +737,15 @@ extension SdImageGenerationServiceProvider {
userFunction: self.generateImages(request:context:)
)
case "GenerateImagesStreaming":
return ServerStreamingServerHandler(
context: context,
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
responseSerializer: ProtobufSerializer<SdGenerateImagesStreamUpdate>(),
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [],
userFunction: self.generateImagesStreaming(request:context:)
)
default:
return nil
}
@ -692,6 +769,12 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
request: SdGenerateImagesRequest,
context: GRPCAsyncServerCallContext
) async throws -> SdGenerateImagesResponse
@Sendable func generateImagesStreaming(
request: SdGenerateImagesRequest,
responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,
context: GRPCAsyncServerCallContext
) async throws
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@ -722,6 +805,15 @@ extension SdImageGenerationServiceAsyncProvider {
wrapping: self.generateImages(request:context:)
)
case "GenerateImagesStreaming":
return GRPCAsyncServerHandler(
context: context,
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
responseSerializer: ProtobufSerializer<SdGenerateImagesStreamUpdate>(),
interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [],
wrapping: self.generateImagesStreaming(request:responseStream:context:)
)
default:
return nil
}
@ -735,6 +827,10 @@ public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol {
/// - Returns: Interceptors to use when handling 'generateImages'.
/// Defaults to calling `self.makeInterceptors()`.
func makeGenerateImagesInterceptors() -> [ServerInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
/// - Returns: Interceptors to use when handling 'generateImagesStreaming'.
/// Defaults to calling `self.makeInterceptors()`.
func makeGenerateImagesStreamingInterceptors() -> [ServerInterceptor<SdGenerateImagesRequest, SdGenerateImagesStreamUpdate>]
}
public enum SdImageGenerationServiceServerMetadata {
@ -743,6 +839,7 @@ public enum SdImageGenerationServiceServerMetadata {
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
methods: [
SdImageGenerationServiceServerMetadata.Methods.generateImages,
SdImageGenerationServiceServerMetadata.Methods.generateImagesStreaming,
]
)
@ -752,5 +849,11 @@ public enum SdImageGenerationServiceServerMetadata {
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages",
type: GRPCCallType.unary
)
public static let generateImagesStreaming = GRPCMethodDescriptor(
name: "GenerateImagesStreaming",
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImagesStreaming",
type: GRPCCallType.serverStreaming
)
}
}

View File

@ -442,6 +442,90 @@ public struct SdGenerateImagesResponse {
public init() {}
}
public struct SdGenerateImagesBatchProgressUpdate {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
public var percentageComplete: Float = 0
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
public struct SdGenerateImagesBatchCompletedUpdate {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
public var images: [SdImage] = []
public var seed: UInt32 = 0
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
///*
/// Represents a continuous update from an image generation stream.
public struct SdGenerateImagesStreamUpdate {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
public var currentBatch: UInt32 = 0
public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil
public var batchProgress: SdGenerateImagesBatchProgressUpdate {
get {
if case .batchProgress(let v)? = update {return v}
return SdGenerateImagesBatchProgressUpdate()
}
set {update = .batchProgress(newValue)}
}
public var batchCompleted: SdGenerateImagesBatchCompletedUpdate {
get {
if case .batchCompleted(let v)? = update {return v}
return SdGenerateImagesBatchCompletedUpdate()
}
set {update = .batchCompleted(newValue)}
}
public var overallPercentageComplete: Float = 0
public var unknownFields = SwiftProtobuf.UnknownStorage()
public enum OneOf_Update: Equatable {
case batchProgress(SdGenerateImagesBatchProgressUpdate)
case batchCompleted(SdGenerateImagesBatchCompletedUpdate)
#if !swift(>=4.1)
public static func ==(lhs: SdGenerateImagesStreamUpdate.OneOf_Update, rhs: SdGenerateImagesStreamUpdate.OneOf_Update) -> Bool {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch (lhs, rhs) {
case (.batchProgress, .batchProgress): return {
guard case .batchProgress(let l) = lhs, case .batchProgress(let r) = rhs else { preconditionFailure() }
return l == r
}()
case (.batchCompleted, .batchCompleted): return {
guard case .batchCompleted(let l) = lhs, case .batchCompleted(let r) = rhs else { preconditionFailure() }
return l == r
}()
default: return false
}
}
#endif
}
public init() {}
}
#if swift(>=5.5) && canImport(_Concurrency)
extension SdModelAttention: @unchecked Sendable {}
extension SdScheduler: @unchecked Sendable {}
@ -455,6 +539,10 @@ extension SdLoadModelRequest: @unchecked Sendable {}
extension SdLoadModelResponse: @unchecked Sendable {}
extension SdGenerateImagesRequest: @unchecked Sendable {}
extension SdGenerateImagesResponse: @unchecked Sendable {}
extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
#endif // swift(>=5.5) && canImport(_Concurrency)
// MARK: - Code below here is support for the SwiftProtobuf runtime.
@ -837,3 +925,155 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag
return true
}
}
extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "percentage_complete"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if self.percentageComplete != 0 {
try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool {
if lhs.percentageComplete != rhs.percentageComplete {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}
extension SdGenerateImagesBatchCompletedUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchCompletedUpdate"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "images"),
2: .same(proto: "seed"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if !self.images.isEmpty {
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1)
}
if self.seed != 0 {
try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 2)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdGenerateImagesBatchCompletedUpdate, rhs: SdGenerateImagesBatchCompletedUpdate) -> Bool {
if lhs.images != rhs.images {return false}
if lhs.seed != rhs.seed {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}
extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesStreamUpdate"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "current_batch"),
2: .standard(proto: "batch_progress"),
3: .standard(proto: "batch_completed"),
4: .standard(proto: "overall_percentage_complete"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.currentBatch) }()
case 2: try {
var v: SdGenerateImagesBatchProgressUpdate?
var hadOneofValue = false
if let current = self.update {
hadOneofValue = true
if case .batchProgress(let m) = current {v = m}
}
try decoder.decodeSingularMessageField(value: &v)
if let v = v {
if hadOneofValue {try decoder.handleConflictingOneOf()}
self.update = .batchProgress(v)
}
}()
case 3: try {
var v: SdGenerateImagesBatchCompletedUpdate?
var hadOneofValue = false
if let current = self.update {
hadOneofValue = true
if case .batchCompleted(let m) = current {v = m}
}
try decoder.decodeSingularMessageField(value: &v)
if let v = v {
if hadOneofValue {try decoder.handleConflictingOneOf()}
self.update = .batchCompleted(v)
}
}()
case 4: try { try decoder.decodeSingularFloatField(value: &self.overallPercentageComplete) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every if/case branch local when no optimizations
// are enabled. https://github.com/apple/swift-protobuf/issues/1034 and
// https://github.com/apple/swift-protobuf/issues/1182
if self.currentBatch != 0 {
try visitor.visitSingularUInt32Field(value: self.currentBatch, fieldNumber: 1)
}
switch self.update {
case .batchProgress?: try {
guard case .batchProgress(let v)? = self.update else { preconditionFailure() }
try visitor.visitSingularMessageField(value: v, fieldNumber: 2)
}()
case .batchCompleted?: try {
guard case .batchCompleted(let v)? = self.update else { preconditionFailure() }
try visitor.visitSingularMessageField(value: v, fieldNumber: 3)
}()
case nil: break
}
if self.overallPercentageComplete != 0 {
try visitor.visitSingularFloatField(value: self.overallPercentageComplete, fieldNumber: 4)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdGenerateImagesStreamUpdate, rhs: SdGenerateImagesStreamUpdate) -> Bool {
if lhs.currentBatch != rhs.currentBatch {return false}
if lhs.update != rhs.update {return false}
if lhs.overallPercentageComplete != rhs.overallPercentageComplete {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}

View File

@ -11,14 +11,16 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
}
func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
do {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdCoreError.modelNotFound
}
return try await state.generate(request)
} catch {
print(error)
throw error
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdCoreError.modelNotFound
}
return try await state.generate(request)
}
func generateImagesStreaming(request: SdGenerateImagesRequest, responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>, context _: GRPCAsyncServerCallContext) async throws {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdCoreError.modelNotFound
}
try await state.generateStreaming(request, stream: responseStream)
}
}