mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 05:30:54 +00:00
Implement the ability to optionally collect intermediate images.
This commit is contained in:
1
Clients/Java/.gitignore
vendored
1
Clients/Java/.gitignore
vendored
@ -5,5 +5,4 @@ build/
|
||||
out/
|
||||
/work
|
||||
/kotlin-js-store
|
||||
/work
|
||||
/.fleet/*
|
||||
|
@ -63,16 +63,22 @@ fun main() {
|
||||
startingImage = image
|
||||
}
|
||||
}.build()
|
||||
for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) {
|
||||
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
|
||||
if (update.hasBatchProgress()) {
|
||||
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
|
||||
for ((index, image) in update.batchProgress.imagesList.withIndex()) {
|
||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||
println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
}
|
||||
}
|
||||
|
||||
if (update.hasBatchCompleted()) {
|
||||
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||
val path = Path("work/image${imageIndex}.${image.format.name}")
|
||||
val path = Path("work/final_${imageIndex}.${image.format.name}")
|
||||
path.writeBytes(image.data.toByteArray())
|
||||
}
|
||||
}
|
||||
|
@ -253,6 +253,12 @@ message GenerateImagesRequest {
|
||||
* If not specified, a reasonable default value is used.
|
||||
*/
|
||||
uint32 step_count = 13;
|
||||
|
||||
/**
|
||||
* Indicates whether to send intermediate images
|
||||
* while in streaming mode.
|
||||
*/
|
||||
bool send_intermediates = 14;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -263,7 +269,7 @@ message GenerateImagesResponse {
|
||||
* The set of generated images by the Stable Diffusion pipeline.
|
||||
*/
|
||||
repeated Image images = 1;
|
||||
|
||||
|
||||
/**
|
||||
* The seeds that were used to generate the images.
|
||||
*/
|
||||
@ -278,6 +284,14 @@ message GenerateImagesBatchProgressUpdate {
|
||||
* The percentage of this batch that is complete.
|
||||
*/
|
||||
float percentage_complete = 1;
|
||||
|
||||
/**
|
||||
* The current state of the generated images from this batch.
|
||||
* These are not usually completed images, but partial images.
|
||||
* These are only available if the request's send_intermediates
|
||||
* parameter is set to true.
|
||||
*/
|
||||
repeated Image images = 2;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -84,18 +84,22 @@ public actor ModelState {
|
||||
pipelineConfig.seed = seed
|
||||
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
|
||||
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
|
||||
var images: [SdImage]?
|
||||
if request.sendIntermediates {
|
||||
images = try? cgImagesToImages(request: request, progress.currentImages)
|
||||
}
|
||||
let finalImages = images
|
||||
Task {
|
||||
do {
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchProgress = .with { update in
|
||||
update.percentageComplete = percentage
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchProgress = .with { update in
|
||||
update.percentageComplete = percentage
|
||||
if let finalImages {
|
||||
update.images = finalImages
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(percentage)
|
||||
})
|
||||
} catch {
|
||||
fatalError(error.localizedDescription)
|
||||
}
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(percentage)
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
@ -327,7 +327,8 @@ extension SdImageGenerationServiceClientProtocol {
|
||||
)
|
||||
}
|
||||
|
||||
/// Server streaming call to GenerateImagesStreaming
|
||||
///*
|
||||
/// Generates images using a loaded model, providing updates along the way.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - request: Request to send to GenerateImagesStreaming.
|
||||
@ -713,6 +714,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
|
||||
/// Generates images using a loaded model.
|
||||
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
|
||||
|
||||
///*
|
||||
/// Generates images using a loaded model, providing updates along the way.
|
||||
func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext<SdGenerateImagesStreamUpdate>) -> EventLoopFuture<GRPCStatus>
|
||||
}
|
||||
|
||||
@ -770,6 +773,8 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
|
||||
context: GRPCAsyncServerCallContext
|
||||
) async throws -> SdGenerateImagesResponse
|
||||
|
||||
///*
|
||||
/// Generates images using a loaded model, providing updates along the way.
|
||||
@Sendable func generateImagesStreaming(
|
||||
request: SdGenerateImagesRequest,
|
||||
responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,
|
||||
|
@ -415,6 +415,11 @@ public struct SdGenerateImagesRequest {
|
||||
/// If not specified, a reasonable default value is used.
|
||||
public var stepCount: UInt32 = 0
|
||||
|
||||
///*
|
||||
/// Indicates whether to send intermediate images
|
||||
/// while in streaming mode.
|
||||
public var sendIntermediates: Bool = false
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
@ -442,25 +447,42 @@ public struct SdGenerateImagesResponse {
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a progress update for an image generation batch.
|
||||
public struct SdGenerateImagesBatchProgressUpdate {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The percentage of this batch that is complete.
|
||||
public var percentageComplete: Float = 0
|
||||
|
||||
///*
|
||||
/// The current state of the generated images from this batch.
|
||||
/// These are not usually completed images, but partial images.
|
||||
/// These are only available if the request's send_intermediates
|
||||
/// parameter is set to true.
|
||||
public var images: [SdImage] = []
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a completion of an image generation batch.
|
||||
public struct SdGenerateImagesBatchCompletedUpdate {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The generated images from this batch.
|
||||
public var images: [SdImage] = []
|
||||
|
||||
///*
|
||||
/// The seed for this batch.
|
||||
public var seed: UInt32 = 0
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
@ -475,10 +497,16 @@ public struct SdGenerateImagesStreamUpdate {
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The current batch number that is processing.
|
||||
public var currentBatch: UInt32 = 0
|
||||
|
||||
///*
|
||||
/// An update to the image generation pipeline.
|
||||
public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil
|
||||
|
||||
///*
|
||||
/// Batch progress update.
|
||||
public var batchProgress: SdGenerateImagesBatchProgressUpdate {
|
||||
get {
|
||||
if case .batchProgress(let v)? = update {return v}
|
||||
@ -487,6 +515,8 @@ public struct SdGenerateImagesStreamUpdate {
|
||||
set {update = .batchProgress(newValue)}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Batch completion update.
|
||||
public var batchCompleted: SdGenerateImagesBatchCompletedUpdate {
|
||||
get {
|
||||
if case .batchCompleted(let v)? = update {return v}
|
||||
@ -499,8 +529,14 @@ public struct SdGenerateImagesStreamUpdate {
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
///*
|
||||
/// An update to the image generation pipeline.
|
||||
public enum OneOf_Update: Equatable {
|
||||
///*
|
||||
/// Batch progress update.
|
||||
case batchProgress(SdGenerateImagesBatchProgressUpdate)
|
||||
///*
|
||||
/// Batch completion update.
|
||||
case batchCompleted(SdGenerateImagesBatchCompletedUpdate)
|
||||
|
||||
#if !swift(>=4.1)
|
||||
@ -796,6 +832,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
11: .standard(proto: "guidance_scale"),
|
||||
12: .same(proto: "strength"),
|
||||
13: .standard(proto: "step_count"),
|
||||
14: .standard(proto: "send_intermediates"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
@ -817,6 +854,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }()
|
||||
case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }()
|
||||
case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }()
|
||||
case 14: try { try decoder.decodeSingularBoolField(value: &self.sendIntermediates) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
@ -866,6 +904,9 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
if self.stepCount != 0 {
|
||||
try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13)
|
||||
}
|
||||
if self.sendIntermediates != false {
|
||||
try visitor.visitSingularBoolField(value: self.sendIntermediates, fieldNumber: 14)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
@ -883,6 +924,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
if lhs.guidanceScale != rhs.guidanceScale {return false}
|
||||
if lhs.strength != rhs.strength {return false}
|
||||
if lhs.stepCount != rhs.stepCount {return false}
|
||||
if lhs.sendIntermediates != rhs.sendIntermediates {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
@ -930,6 +972,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
|
||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .standard(proto: "percentage_complete"),
|
||||
2: .same(proto: "images"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
@ -939,6 +982,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }()
|
||||
case 2: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
@ -948,11 +992,15 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
|
||||
if self.percentageComplete != 0 {
|
||||
try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1)
|
||||
}
|
||||
if !self.images.isEmpty {
|
||||
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 2)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool {
|
||||
if lhs.percentageComplete != rhs.percentageComplete {return false}
|
||||
if lhs.images != rhs.images {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
|
Reference in New Issue
Block a user