Implement the ability to optionally collect intermediate images.

This commit is contained in:
2023-04-23 14:50:45 -07:00
parent d24d299a7d
commit 5704654d1a
6 changed files with 91 additions and 15 deletions

View File

@ -5,5 +5,4 @@ build/
out/
/work
/kotlin-js-store
/work
/.fleet/*

View File

@ -63,16 +63,22 @@ fun main() {
startingImage = image
}
}.build()
for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) {
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
if (update.hasBatchProgress()) {
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
for ((index, image) in update.batchProgress.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
}
}
if (update.hasBatchCompleted()) {
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/image${imageIndex}.${image.format.name}")
val path = Path("work/final_${imageIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray())
}
}

View File

@ -253,6 +253,12 @@ message GenerateImagesRequest {
* If not specified, a reasonable default value is used.
*/
uint32 step_count = 13;
/**
* Indicates whether to send intermediate images
* while in streaming mode.
*/
bool send_intermediates = 14;
}
/**
@ -263,7 +269,7 @@ message GenerateImagesResponse {
* The set of generated images by the Stable Diffusion pipeline.
*/
repeated Image images = 1;
/**
* The seeds that were used to generate the images.
*/
@ -278,6 +284,14 @@ message GenerateImagesBatchProgressUpdate {
* The percentage of this batch that is complete.
*/
float percentage_complete = 1;
/**
* The current state of the generated images from this batch.
* These are not usually completed images, but partial images.
* These are only available if the request's send_intermediates
* parameter is set to true.
*/
repeated Image images = 2;
}
/**

View File

@ -84,18 +84,22 @@ public actor ModelState {
pipelineConfig.seed = seed
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
var images: [SdImage]?
if request.sendIntermediates {
images = try? cgImagesToImages(request: request, progress.currentImages)
}
let finalImages = images
Task {
do {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
if let finalImages {
update.images = finalImages
}
item.overallPercentageComplete = currentOverallPercentage(percentage)
})
} catch {
fatalError(error.localizedDescription)
}
}
item.overallPercentageComplete = currentOverallPercentage(percentage)
})
}
return true
})

View File

@ -327,7 +327,8 @@ extension SdImageGenerationServiceClientProtocol {
)
}
/// Server streaming call to GenerateImagesStreaming
///*
/// Generates images using a loaded model, providing updates along the way.
///
/// - Parameters:
/// - request: Request to send to GenerateImagesStreaming.
@ -713,6 +714,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
/// Generates images using a loaded model.
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
///*
/// Generates images using a loaded model, providing updates along the way.
func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext<SdGenerateImagesStreamUpdate>) -> EventLoopFuture<GRPCStatus>
}
@ -770,6 +773,8 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
context: GRPCAsyncServerCallContext
) async throws -> SdGenerateImagesResponse
///*
/// Generates images using a loaded model, providing updates along the way.
@Sendable func generateImagesStreaming(
request: SdGenerateImagesRequest,
responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,

View File

@ -415,6 +415,11 @@ public struct SdGenerateImagesRequest {
/// If not specified, a reasonable default value is used.
public var stepCount: UInt32 = 0
///*
/// Indicates whether to send intermediate images
/// while in streaming mode.
public var sendIntermediates: Bool = false
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
@ -442,25 +447,42 @@ public struct SdGenerateImagesResponse {
public init() {}
}
///*
/// Represents a progress update for an image generation batch.
public struct SdGenerateImagesBatchProgressUpdate {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The percentage of this batch that is complete.
public var percentageComplete: Float = 0
///*
/// The current state of the generated images from this batch.
/// These are not usually completed images, but partial images.
/// These are only available if the request's send_intermediates
/// parameter is set to true.
public var images: [SdImage] = []
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
///*
/// Represents a completion of an image generation batch.
public struct SdGenerateImagesBatchCompletedUpdate {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The generated images from this batch.
public var images: [SdImage] = []
///*
/// The seed for this batch.
public var seed: UInt32 = 0
public var unknownFields = SwiftProtobuf.UnknownStorage()
@ -475,10 +497,16 @@ public struct SdGenerateImagesStreamUpdate {
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The current batch number that is processing.
public var currentBatch: UInt32 = 0
///*
/// An update to the image generation pipeline.
public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil
///*
/// Batch progress update.
public var batchProgress: SdGenerateImagesBatchProgressUpdate {
get {
if case .batchProgress(let v)? = update {return v}
@ -487,6 +515,8 @@ public struct SdGenerateImagesStreamUpdate {
set {update = .batchProgress(newValue)}
}
///*
/// Batch completion update.
public var batchCompleted: SdGenerateImagesBatchCompletedUpdate {
get {
if case .batchCompleted(let v)? = update {return v}
@ -499,8 +529,14 @@ public struct SdGenerateImagesStreamUpdate {
public var unknownFields = SwiftProtobuf.UnknownStorage()
///*
/// An update to the image generation pipeline.
public enum OneOf_Update: Equatable {
///*
/// Batch progress update.
case batchProgress(SdGenerateImagesBatchProgressUpdate)
///*
/// Batch completion update.
case batchCompleted(SdGenerateImagesBatchCompletedUpdate)
#if !swift(>=4.1)
@ -796,6 +832,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
11: .standard(proto: "guidance_scale"),
12: .same(proto: "strength"),
13: .standard(proto: "step_count"),
14: .standard(proto: "send_intermediates"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -817,6 +854,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }()
case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }()
case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }()
case 14: try { try decoder.decodeSingularBoolField(value: &self.sendIntermediates) }()
default: break
}
}
@ -866,6 +904,9 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
if self.stepCount != 0 {
try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13)
}
if self.sendIntermediates != false {
try visitor.visitSingularBoolField(value: self.sendIntermediates, fieldNumber: 14)
}
try unknownFields.traverse(visitor: &visitor)
}
@ -883,6 +924,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
if lhs.guidanceScale != rhs.guidanceScale {return false}
if lhs.strength != rhs.strength {return false}
if lhs.stepCount != rhs.stepCount {return false}
if lhs.sendIntermediates != rhs.sendIntermediates {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
@ -930,6 +972,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "percentage_complete"),
2: .same(proto: "images"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -939,6 +982,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }()
case 2: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
default: break
}
}
@ -948,11 +992,15 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
if self.percentageComplete != 0 {
try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1)
}
if !self.images.isEmpty {
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 2)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool {
if lhs.percentageComplete != rhs.percentageComplete {return false}
if lhs.images != rhs.images {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}