diff --git a/Clients/Cpp/.gitignore b/Clients/Cpp/.gitignore new file mode 100644 index 0000000..6efa2ee --- /dev/null +++ b/Clients/Cpp/.gitignore @@ -0,0 +1,5 @@ +/cmake-build-* +/.idea +/.vscode +/src/*.grpc.* +/src/*.pb.* diff --git a/Clients/Cpp/CMakeLists.txt b/Clients/Cpp/CMakeLists.txt new file mode 100644 index 0000000..a1ec46d --- /dev/null +++ b/Clients/Cpp/CMakeLists.txt @@ -0,0 +1,22 @@ +cmake_minimum_required(VERSION 3.20) +project(sdrpc) + +find_package(Protobuf CONFIG REQUIRED) +find_package(gRPC CONFIG REQUIRED) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + +add_library(sdrpc src/StableDiffusion.proto) + +get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) + +protobuf_generate(TARGET sdrpc LANGUAGE cpp) +protobuf_generate(TARGET sdrpc LANGUAGE grpc + GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc + PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") +target_include_directories(sdrpc PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +target_link_libraries(sdrpc PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) + +add_executable(sdrpc_sample src/sample.cpp) +target_include_directories(sdrpc_sample PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +target_link_libraries(sdrpc_sample PRIVATE sdrpc) diff --git a/Clients/Cpp/src/StableDiffusion.proto b/Clients/Cpp/src/StableDiffusion.proto new file mode 120000 index 0000000..2839d83 --- /dev/null +++ b/Clients/Cpp/src/StableDiffusion.proto @@ -0,0 +1 @@ +../../../Common/StableDiffusion.proto \ No newline at end of file diff --git a/Clients/Cpp/src/sample.cpp b/Clients/Cpp/src/sample.cpp new file mode 100644 index 0000000..b34abd8 --- /dev/null +++ b/Clients/Cpp/src/sample.cpp @@ -0,0 +1,11 @@ +#include +#include + +using namespace gay::pizza::stable::diffusion; + +int main() { + ModelInfo info; + info.set_name("anything-4.5"); + std::cout << info.DebugString() << std::endl; + return 0; +} diff --git a/Clients/Java/build.gradle.kts b/Clients/Java/build.gradle.kts index efafde7..b1351a2 100644 --- a/Clients/Java/build.gradle.kts +++ b/Clients/Java/build.gradle.kts @@ -30,6 +30,8 @@ dependencies { implementation("org.jetbrains.kotlin:kotlin-bom") implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.0-RC") + api("io.grpc:grpc-stub:1.54.1") api("io.grpc:grpc-protobuf:1.54.1") api("io.grpc:grpc-kotlin-stub:1.3.0") diff --git a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt index aa44a07..4ae18ca 100644 --- a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt +++ b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt @@ -63,13 +63,20 @@ fun main() { startingImage = image } }.build() - val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request) + for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) { + if (update.hasBatchProgress()) { + println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%") + } - println("generated ${generateImagesResponse.imagesCount} images:") - for ((index, image) in generateImagesResponse.imagesList.withIndex()) { - println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)") - val path = Path("work/image${index}.${image.format.name}") - path.writeBytes(image.data.toByteArray()) + if (update.hasBatchCompleted()) { + for ((index, image) in update.batchCompleted.imagesList.withIndex()) { + val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) + println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") + val path = Path("work/image${imageIndex}.${image.format.name}") + path.writeBytes(image.data.toByteArray()) + } + } + println("overall progress ${update.overallPercentageComplete}%") } channel.shutdownNow() diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index 50bdd25..d069b4b 100644 --- a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -270,6 +270,29 @@ message GenerateImagesResponse { repeated uint32 seeds = 2; } +message GenerateImagesBatchProgressUpdate { + float percentage_complete = 1; +} + +message GenerateImagesBatchCompletedUpdate { + repeated Image images = 1; + uint32 seed = 2; +} + +/** + * Represents a continuous update from an image generation stream. + */ +message GenerateImagesStreamUpdate { + uint32 current_batch = 1; + + oneof update { + GenerateImagesBatchProgressUpdate batch_progress = 2; + GenerateImagesBatchCompletedUpdate batch_completed = 3; + } + + float overall_percentage_complete = 4; +} + /** * The image generation service, for generating images from loaded models. */ @@ -278,4 +301,6 @@ service ImageGenerationService { * Generates images using a loaded model. */ rpc GenerateImages(GenerateImagesRequest) returns (GenerateImagesResponse); + + rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate); } diff --git a/Sources/StableDiffusionCore/ModelState.swift b/Sources/StableDiffusionCore/ModelState.swift index 866e319..b4cbf9d 100644 --- a/Sources/StableDiffusionCore/ModelState.swift +++ b/Sources/StableDiffusionCore/ModelState.swift @@ -1,5 +1,6 @@ import CoreML import Foundation +import GRPC import StableDiffusion import StableDiffusionProtos @@ -44,7 +45,82 @@ public actor ModelState { } let baseSeed: UInt32 = request.seed + var pipelineConfig = try toPipelineConfig(request) + var response = SdGenerateImagesResponse() + for _ in 0 ..< request.batchCount { + var seed = baseSeed + if seed == 0 { + seed = UInt32.random(in: 0 ..< UInt32.max) + } + pipelineConfig.seed = seed + let images = try pipeline.generateImages(configuration: pipelineConfig) + try response.images.append(contentsOf: cgImagesToImages(request: request, images)) + response.seeds.append(seed) + } + return response + } + + public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter) async throws { + guard let pipeline else { + throw SdCoreError.modelNotLoaded + } + + let baseSeed: UInt32 = request.seed + var pipelineConfig = try toPipelineConfig(request) + + for batch in 1 ... request.batchCount { + @Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float { + let eachSegment = 100.0 / Float(request.batchCount) + let alreadyCompletedSegments = (Float(batch) - 1) * eachSegment + let percentageToAdd = eachSegment * (batchPercentage / 100.0) + return alreadyCompletedSegments + percentageToAdd + } + + var seed = baseSeed + if seed == 0 { + seed = UInt32.random(in: 0 ..< UInt32.max) + } + pipelineConfig.seed = seed + let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in + let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0 + Task { + do { + try await stream.send(.with { item in + item.currentBatch = batch + item.batchProgress = .with { update in + update.percentageComplete = percentage + } + item.overallPercentageComplete = currentOverallPercentage(percentage) + }) + } catch { + fatalError(error.localizedDescription) + } + } + return true + }) + let images = try cgImagesToImages(request: request, cgImages) + try await stream.send(.with { item in + item.currentBatch = batch + item.batchCompleted = .with { update in + update.images = images + update.seed = seed + } + item.overallPercentageComplete = currentOverallPercentage(100.0) + }) + } + } + + private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] { + var images: [SdImage] = [] + for cgImage in cgImages { + guard let cgImage else { continue } + try images.append(cgImage.toSdImage(format: request.outputImageFormat)) + } + return images + } + + private func toPipelineConfig(_ request: SdGenerateImagesRequest) throws -> StableDiffusionPipeline.Configuration { var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt) pipelineConfig.negativePrompt = request.negativePrompt pipelineConfig.imageCount = Int(request.batchSize) @@ -72,22 +148,6 @@ public actor ModelState { case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler default: pipelineConfig.schedulerType = .pndmScheduler } - - var response = SdGenerateImagesResponse() - for _ in 0 ..< request.batchCount { - var seed = baseSeed - if seed == 0 { - seed = UInt32.random(in: 0 ..< UInt32.max) - } - pipelineConfig.seed = seed - let images = try pipeline.generateImages(configuration: pipelineConfig) - - for cgImage in images { - guard let cgImage else { continue } - try response.images.append(cgImage.toSdImage(format: request.outputImageFormat)) - } - response.seeds.append(seed) - } - return response + return pipelineConfig } } diff --git a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift index e0eff21..389be37 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift @@ -295,6 +295,12 @@ public protocol SdImageGenerationServiceClientProtocol: GRPCClient { _ request: SdGenerateImagesRequest, callOptions: CallOptions? ) -> UnaryCall + + func generateImagesStreaming( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions?, + handler: @escaping (SdGenerateImagesStreamUpdate) -> Void + ) -> ServerStreamingCall } extension SdImageGenerationServiceClientProtocol { @@ -320,6 +326,27 @@ extension SdImageGenerationServiceClientProtocol { interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [] ) } + + /// Server streaming call to GenerateImagesStreaming + /// + /// - Parameters: + /// - request: Request to send to GenerateImagesStreaming. + /// - callOptions: Call options. + /// - handler: A closure called when each response is received from the server. + /// - Returns: A `ServerStreamingCall` with futures for the metadata and status. + public func generateImagesStreaming( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? = nil, + handler: @escaping (SdGenerateImagesStreamUpdate) -> Void + ) -> ServerStreamingCall { + return self.makeServerStreamingCall( + path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [], + handler: handler + ) + } } #if compiler(>=5.6) @@ -393,6 +420,11 @@ public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient { _ request: SdGenerateImagesRequest, callOptions: CallOptions? ) -> GRPCAsyncUnaryCall + + func makeGenerateImagesStreamingCall( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? + ) -> GRPCAsyncServerStreamingCall } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -416,6 +448,18 @@ extension SdImageGenerationServiceAsyncClientProtocol { interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [] ) } + + public func makeGenerateImagesStreamingCall( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncServerStreamingCall { + return self.makeAsyncServerStreamingCall( + path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [] + ) + } } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -431,6 +475,18 @@ extension SdImageGenerationServiceAsyncClientProtocol { interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [] ) } + + public func generateImagesStreaming( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncResponseStream { + return self.performAsyncServerStreamingCall( + path: SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [] + ) + } } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -456,6 +512,9 @@ public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSe /// - Returns: Interceptors to use when invoking 'generateImages'. func makeGenerateImagesInterceptors() -> [ClientInterceptor] + + /// - Returns: Interceptors to use when invoking 'generateImagesStreaming'. + func makeGenerateImagesStreamingInterceptors() -> [ClientInterceptor] } public enum SdImageGenerationServiceClientMetadata { @@ -464,6 +523,7 @@ public enum SdImageGenerationServiceClientMetadata { fullName: "gay.pizza.stable.diffusion.ImageGenerationService", methods: [ SdImageGenerationServiceClientMetadata.Methods.generateImages, + SdImageGenerationServiceClientMetadata.Methods.generateImagesStreaming, ] ) @@ -473,6 +533,12 @@ public enum SdImageGenerationServiceClientMetadata { path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages", type: GRPCCallType.unary ) + + public static let generateImagesStreaming = GRPCMethodDescriptor( + name: "GenerateImagesStreaming", + path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImagesStreaming", + type: GRPCCallType.serverStreaming + ) } } @@ -646,6 +712,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider { ///* /// Generates images using a loaded model. func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture + + func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext) -> EventLoopFuture } extension SdImageGenerationServiceProvider { @@ -669,6 +737,15 @@ extension SdImageGenerationServiceProvider { userFunction: self.generateImages(request:context:) ) + case "GenerateImagesStreaming": + return ServerStreamingServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [], + userFunction: self.generateImagesStreaming(request:context:) + ) + default: return nil } @@ -692,6 +769,12 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider { request: SdGenerateImagesRequest, context: GRPCAsyncServerCallContext ) async throws -> SdGenerateImagesResponse + + @Sendable func generateImagesStreaming( + request: SdGenerateImagesRequest, + responseStream: GRPCAsyncResponseStreamWriter, + context: GRPCAsyncServerCallContext + ) async throws } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -722,6 +805,15 @@ extension SdImageGenerationServiceAsyncProvider { wrapping: self.generateImages(request:context:) ) + case "GenerateImagesStreaming": + return GRPCAsyncServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeGenerateImagesStreamingInterceptors() ?? [], + wrapping: self.generateImagesStreaming(request:responseStream:context:) + ) + default: return nil } @@ -735,6 +827,10 @@ public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol { /// - Returns: Interceptors to use when handling 'generateImages'. /// Defaults to calling `self.makeInterceptors()`. func makeGenerateImagesInterceptors() -> [ServerInterceptor] + + /// - Returns: Interceptors to use when handling 'generateImagesStreaming'. + /// Defaults to calling `self.makeInterceptors()`. + func makeGenerateImagesStreamingInterceptors() -> [ServerInterceptor] } public enum SdImageGenerationServiceServerMetadata { @@ -743,6 +839,7 @@ public enum SdImageGenerationServiceServerMetadata { fullName: "gay.pizza.stable.diffusion.ImageGenerationService", methods: [ SdImageGenerationServiceServerMetadata.Methods.generateImages, + SdImageGenerationServiceServerMetadata.Methods.generateImagesStreaming, ] ) @@ -752,5 +849,11 @@ public enum SdImageGenerationServiceServerMetadata { path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages", type: GRPCCallType.unary ) + + public static let generateImagesStreaming = GRPCMethodDescriptor( + name: "GenerateImagesStreaming", + path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImagesStreaming", + type: GRPCCallType.serverStreaming + ) } } diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift index 203c56c..58e2f23 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -442,6 +442,90 @@ public struct SdGenerateImagesResponse { public init() {} } +public struct SdGenerateImagesBatchProgressUpdate { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var percentageComplete: Float = 0 + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdGenerateImagesBatchCompletedUpdate { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var images: [SdImage] = [] + + public var seed: UInt32 = 0 + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +///* +/// Represents a continuous update from an image generation stream. +public struct SdGenerateImagesStreamUpdate { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var currentBatch: UInt32 = 0 + + public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil + + public var batchProgress: SdGenerateImagesBatchProgressUpdate { + get { + if case .batchProgress(let v)? = update {return v} + return SdGenerateImagesBatchProgressUpdate() + } + set {update = .batchProgress(newValue)} + } + + public var batchCompleted: SdGenerateImagesBatchCompletedUpdate { + get { + if case .batchCompleted(let v)? = update {return v} + return SdGenerateImagesBatchCompletedUpdate() + } + set {update = .batchCompleted(newValue)} + } + + public var overallPercentageComplete: Float = 0 + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public enum OneOf_Update: Equatable { + case batchProgress(SdGenerateImagesBatchProgressUpdate) + case batchCompleted(SdGenerateImagesBatchCompletedUpdate) + + #if !swift(>=4.1) + public static func ==(lhs: SdGenerateImagesStreamUpdate.OneOf_Update, rhs: SdGenerateImagesStreamUpdate.OneOf_Update) -> Bool { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch (lhs, rhs) { + case (.batchProgress, .batchProgress): return { + guard case .batchProgress(let l) = lhs, case .batchProgress(let r) = rhs else { preconditionFailure() } + return l == r + }() + case (.batchCompleted, .batchCompleted): return { + guard case .batchCompleted(let l) = lhs, case .batchCompleted(let r) = rhs else { preconditionFailure() } + return l == r + }() + default: return false + } + } + #endif + } + + public init() {} +} + #if swift(>=5.5) && canImport(_Concurrency) extension SdModelAttention: @unchecked Sendable {} extension SdScheduler: @unchecked Sendable {} @@ -455,6 +539,10 @@ extension SdLoadModelRequest: @unchecked Sendable {} extension SdLoadModelResponse: @unchecked Sendable {} extension SdGenerateImagesRequest: @unchecked Sendable {} extension SdGenerateImagesResponse: @unchecked Sendable {} +extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {} +extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {} +extension SdGenerateImagesStreamUpdate: @unchecked Sendable {} +extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {} #endif // swift(>=5.5) && canImport(_Concurrency) // MARK: - Code below here is support for the SwiftProtobuf runtime. @@ -837,3 +925,155 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag return true } } + +extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "percentage_complete"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if self.percentageComplete != 0 { + try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool { + if lhs.percentageComplete != rhs.percentageComplete {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdGenerateImagesBatchCompletedUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchCompletedUpdate" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .same(proto: "images"), + 2: .same(proto: "seed"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }() + case 2: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.images.isEmpty { + try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1) + } + if self.seed != 0 { + try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 2) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdGenerateImagesBatchCompletedUpdate, rhs: SdGenerateImagesBatchCompletedUpdate) -> Bool { + if lhs.images != rhs.images {return false} + if lhs.seed != rhs.seed {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".GenerateImagesStreamUpdate" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "current_batch"), + 2: .standard(proto: "batch_progress"), + 3: .standard(proto: "batch_completed"), + 4: .standard(proto: "overall_percentage_complete"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularUInt32Field(value: &self.currentBatch) }() + case 2: try { + var v: SdGenerateImagesBatchProgressUpdate? + var hadOneofValue = false + if let current = self.update { + hadOneofValue = true + if case .batchProgress(let m) = current {v = m} + } + try decoder.decodeSingularMessageField(value: &v) + if let v = v { + if hadOneofValue {try decoder.handleConflictingOneOf()} + self.update = .batchProgress(v) + } + }() + case 3: try { + var v: SdGenerateImagesBatchCompletedUpdate? + var hadOneofValue = false + if let current = self.update { + hadOneofValue = true + if case .batchCompleted(let m) = current {v = m} + } + try decoder.decodeSingularMessageField(value: &v) + if let v = v { + if hadOneofValue {try decoder.handleConflictingOneOf()} + self.update = .batchCompleted(v) + } + }() + case 4: try { try decoder.decodeSingularFloatField(value: &self.overallPercentageComplete) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every if/case branch local when no optimizations + // are enabled. https://github.com/apple/swift-protobuf/issues/1034 and + // https://github.com/apple/swift-protobuf/issues/1182 + if self.currentBatch != 0 { + try visitor.visitSingularUInt32Field(value: self.currentBatch, fieldNumber: 1) + } + switch self.update { + case .batchProgress?: try { + guard case .batchProgress(let v)? = self.update else { preconditionFailure() } + try visitor.visitSingularMessageField(value: v, fieldNumber: 2) + }() + case .batchCompleted?: try { + guard case .batchCompleted(let v)? = self.update else { preconditionFailure() } + try visitor.visitSingularMessageField(value: v, fieldNumber: 3) + }() + case nil: break + } + if self.overallPercentageComplete != 0 { + try visitor.visitSingularFloatField(value: self.overallPercentageComplete, fieldNumber: 4) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdGenerateImagesStreamUpdate, rhs: SdGenerateImagesStreamUpdate) -> Bool { + if lhs.currentBatch != rhs.currentBatch {return false} + if lhs.update != rhs.update {return false} + if lhs.overallPercentageComplete != rhs.overallPercentageComplete {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} diff --git a/Sources/StableDiffusionServer/ImageGenerationService.swift b/Sources/StableDiffusionServer/ImageGenerationService.swift index b06eddd..5cf3f30 100644 --- a/Sources/StableDiffusionServer/ImageGenerationService.swift +++ b/Sources/StableDiffusionServer/ImageGenerationService.swift @@ -11,14 +11,16 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider { } func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse { - do { - guard let state = await modelManager.getModelState(name: request.modelName) else { - throw SdCoreError.modelNotFound - } - return try await state.generate(request) - } catch { - print(error) - throw error + guard let state = await modelManager.getModelState(name: request.modelName) else { + throw SdCoreError.modelNotFound } + return try await state.generate(request) + } + + func generateImagesStreaming(request: SdGenerateImagesRequest, responseStream: GRPCAsyncResponseStreamWriter, context _: GRPCAsyncServerCallContext) async throws { + guard let state = await modelManager.getModelState(name: request.modelName) else { + throw SdCoreError.modelNotFound + } + try await state.generateStreaming(request, stream: responseStream) } }