diff --git a/Clients/Java/.gitignore b/Clients/Java/.gitignore index 3704d10..d468ff9 100644 --- a/Clients/Java/.gitignore +++ b/Clients/Java/.gitignore @@ -5,5 +5,4 @@ build/ out/ /work /kotlin-js-store -/work /.fleet/* diff --git a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt index 4ae18ca..8206d53 100644 --- a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt +++ b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt @@ -63,16 +63,22 @@ fun main() { startingImage = image } }.build() - for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) { + for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) { if (update.hasBatchProgress()) { println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%") + for ((index, image) in update.batchProgress.imagesList.withIndex()) { + val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) + println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") + val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}") + path.writeBytes(image.data.toByteArray()) + } } if (update.hasBatchCompleted()) { for ((index, image) in update.batchCompleted.imagesList.withIndex()) { val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") - val path = Path("work/image${imageIndex}.${image.format.name}") + val path = Path("work/final_${imageIndex}.${image.format.name}") path.writeBytes(image.data.toByteArray()) } } diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index a109416..b3fedfe 100644 --- a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -253,6 +253,12 @@ message GenerateImagesRequest { * If not specified, a reasonable default value is used. */ uint32 step_count = 13; + + /** + * Indicates whether to send intermediate images + * while in streaming mode. + */ + bool send_intermediates = 14; } /** @@ -263,7 +269,7 @@ message GenerateImagesResponse { * The set of generated images by the Stable Diffusion pipeline. */ repeated Image images = 1; - + /** * The seeds that were used to generate the images. */ @@ -278,6 +284,14 @@ message GenerateImagesBatchProgressUpdate { * The percentage of this batch that is complete. */ float percentage_complete = 1; + + /** + * The current state of the generated images from this batch. + * These are not usually completed images, but partial images. + * These are only available if the request's send_intermediates + * parameter is set to true. + */ + repeated Image images = 2; } /** diff --git a/Sources/StableDiffusionCore/ModelState.swift b/Sources/StableDiffusionCore/ModelState.swift index b4cbf9d..de3844f 100644 --- a/Sources/StableDiffusionCore/ModelState.swift +++ b/Sources/StableDiffusionCore/ModelState.swift @@ -84,18 +84,22 @@ public actor ModelState { pipelineConfig.seed = seed let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0 + var images: [SdImage]? + if request.sendIntermediates { + images = try? cgImagesToImages(request: request, progress.currentImages) + } + let finalImages = images Task { - do { - try await stream.send(.with { item in - item.currentBatch = batch - item.batchProgress = .with { update in - update.percentageComplete = percentage + try await stream.send(.with { item in + item.currentBatch = batch + item.batchProgress = .with { update in + update.percentageComplete = percentage + if let finalImages { + update.images = finalImages } - item.overallPercentageComplete = currentOverallPercentage(percentage) - }) - } catch { - fatalError(error.localizedDescription) - } + } + item.overallPercentageComplete = currentOverallPercentage(percentage) + }) } return true }) diff --git a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift index 389be37..96e3a24 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift @@ -327,7 +327,8 @@ extension SdImageGenerationServiceClientProtocol { ) } - /// Server streaming call to GenerateImagesStreaming + ///* + /// Generates images using a loaded model, providing updates along the way. /// /// - Parameters: /// - request: Request to send to GenerateImagesStreaming. @@ -713,6 +714,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider { /// Generates images using a loaded model. func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture + ///* + /// Generates images using a loaded model, providing updates along the way. func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext) -> EventLoopFuture } @@ -770,6 +773,8 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider { context: GRPCAsyncServerCallContext ) async throws -> SdGenerateImagesResponse + ///* + /// Generates images using a loaded model, providing updates along the way. @Sendable func generateImagesStreaming( request: SdGenerateImagesRequest, responseStream: GRPCAsyncResponseStreamWriter, diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift index 58e2f23..b05b458 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -415,6 +415,11 @@ public struct SdGenerateImagesRequest { /// If not specified, a reasonable default value is used. public var stepCount: UInt32 = 0 + ///* + /// Indicates whether to send intermediate images + /// while in streaming mode. + public var sendIntermediates: Bool = false + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} @@ -442,25 +447,42 @@ public struct SdGenerateImagesResponse { public init() {} } +///* +/// Represents a progress update for an image generation batch. public struct SdGenerateImagesBatchProgressUpdate { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The percentage of this batch that is complete. public var percentageComplete: Float = 0 + ///* + /// The current state of the generated images from this batch. + /// These are not usually completed images, but partial images. + /// These are only available if the request's send_intermediates + /// parameter is set to true. + public var images: [SdImage] = [] + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } +///* +/// Represents a completion of an image generation batch. public struct SdGenerateImagesBatchCompletedUpdate { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The generated images from this batch. public var images: [SdImage] = [] + ///* + /// The seed for this batch. public var seed: UInt32 = 0 public var unknownFields = SwiftProtobuf.UnknownStorage() @@ -475,10 +497,16 @@ public struct SdGenerateImagesStreamUpdate { // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The current batch number that is processing. public var currentBatch: UInt32 = 0 + ///* + /// An update to the image generation pipeline. public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil + ///* + /// Batch progress update. public var batchProgress: SdGenerateImagesBatchProgressUpdate { get { if case .batchProgress(let v)? = update {return v} @@ -487,6 +515,8 @@ public struct SdGenerateImagesStreamUpdate { set {update = .batchProgress(newValue)} } + ///* + /// Batch completion update. public var batchCompleted: SdGenerateImagesBatchCompletedUpdate { get { if case .batchCompleted(let v)? = update {return v} @@ -499,8 +529,14 @@ public struct SdGenerateImagesStreamUpdate { public var unknownFields = SwiftProtobuf.UnknownStorage() + ///* + /// An update to the image generation pipeline. public enum OneOf_Update: Equatable { + ///* + /// Batch progress update. case batchProgress(SdGenerateImagesBatchProgressUpdate) + ///* + /// Batch completion update. case batchCompleted(SdGenerateImagesBatchCompletedUpdate) #if !swift(>=4.1) @@ -796,6 +832,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message 11: .standard(proto: "guidance_scale"), 12: .same(proto: "strength"), 13: .standard(proto: "step_count"), + 14: .standard(proto: "send_intermediates"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -817,6 +854,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }() case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }() case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }() + case 14: try { try decoder.decodeSingularBoolField(value: &self.sendIntermediates) }() default: break } } @@ -866,6 +904,9 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message if self.stepCount != 0 { try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13) } + if self.sendIntermediates != false { + try visitor.visitSingularBoolField(value: self.sendIntermediates, fieldNumber: 14) + } try unknownFields.traverse(visitor: &visitor) } @@ -883,6 +924,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message if lhs.guidanceScale != rhs.guidanceScale {return false} if lhs.strength != rhs.strength {return false} if lhs.stepCount != rhs.stepCount {return false} + if lhs.sendIntermediates != rhs.sendIntermediates {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -930,6 +972,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate" public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .standard(proto: "percentage_complete"), + 2: .same(proto: "images"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -939,6 +982,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }() + case 2: try { try decoder.decodeRepeatedMessageField(value: &self.images) }() default: break } } @@ -948,11 +992,15 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto if self.percentageComplete != 0 { try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1) } + if !self.images.isEmpty { + try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 2) + } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool { if lhs.percentageComplete != rhs.percentageComplete {return false} + if lhs.images != rhs.images {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true }