mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-04 05:51:32 +00:00
Implement the ability to optionally collect intermediate images.
This commit is contained in:
1
Clients/Java/.gitignore
vendored
1
Clients/Java/.gitignore
vendored
@ -5,5 +5,4 @@ build/
|
|||||||
out/
|
out/
|
||||||
/work
|
/work
|
||||||
/kotlin-js-store
|
/kotlin-js-store
|
||||||
/work
|
|
||||||
/.fleet/*
|
/.fleet/*
|
||||||
|
@ -63,16 +63,22 @@ fun main() {
|
|||||||
startingImage = image
|
startingImage = image
|
||||||
}
|
}
|
||||||
}.build()
|
}.build()
|
||||||
for (update in client.imageGenerationServiceBlocking.generateImagesStreaming(request)) {
|
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
|
||||||
if (update.hasBatchProgress()) {
|
if (update.hasBatchProgress()) {
|
||||||
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
|
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
|
||||||
|
for ((index, image) in update.batchProgress.imagesList.withIndex()) {
|
||||||
|
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||||
|
println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||||
|
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
|
||||||
|
path.writeBytes(image.data.toByteArray())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (update.hasBatchCompleted()) {
|
if (update.hasBatchCompleted()) {
|
||||||
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
|
||||||
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
|
||||||
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||||
val path = Path("work/image${imageIndex}.${image.format.name}")
|
val path = Path("work/final_${imageIndex}.${image.format.name}")
|
||||||
path.writeBytes(image.data.toByteArray())
|
path.writeBytes(image.data.toByteArray())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,6 +253,12 @@ message GenerateImagesRequest {
|
|||||||
* If not specified, a reasonable default value is used.
|
* If not specified, a reasonable default value is used.
|
||||||
*/
|
*/
|
||||||
uint32 step_count = 13;
|
uint32 step_count = 13;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates whether to send intermediate images
|
||||||
|
* while in streaming mode.
|
||||||
|
*/
|
||||||
|
bool send_intermediates = 14;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -278,6 +284,14 @@ message GenerateImagesBatchProgressUpdate {
|
|||||||
* The percentage of this batch that is complete.
|
* The percentage of this batch that is complete.
|
||||||
*/
|
*/
|
||||||
float percentage_complete = 1;
|
float percentage_complete = 1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The current state of the generated images from this batch.
|
||||||
|
* These are not usually completed images, but partial images.
|
||||||
|
* These are only available if the request's send_intermediates
|
||||||
|
* parameter is set to true.
|
||||||
|
*/
|
||||||
|
repeated Image images = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -84,18 +84,22 @@ public actor ModelState {
|
|||||||
pipelineConfig.seed = seed
|
pipelineConfig.seed = seed
|
||||||
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
|
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
|
||||||
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
|
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
|
||||||
|
var images: [SdImage]?
|
||||||
|
if request.sendIntermediates {
|
||||||
|
images = try? cgImagesToImages(request: request, progress.currentImages)
|
||||||
|
}
|
||||||
|
let finalImages = images
|
||||||
Task {
|
Task {
|
||||||
do {
|
try await stream.send(.with { item in
|
||||||
try await stream.send(.with { item in
|
item.currentBatch = batch
|
||||||
item.currentBatch = batch
|
item.batchProgress = .with { update in
|
||||||
item.batchProgress = .with { update in
|
update.percentageComplete = percentage
|
||||||
update.percentageComplete = percentage
|
if let finalImages {
|
||||||
|
update.images = finalImages
|
||||||
}
|
}
|
||||||
item.overallPercentageComplete = currentOverallPercentage(percentage)
|
}
|
||||||
})
|
item.overallPercentageComplete = currentOverallPercentage(percentage)
|
||||||
} catch {
|
})
|
||||||
fatalError(error.localizedDescription)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
@ -327,7 +327,8 @@ extension SdImageGenerationServiceClientProtocol {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Server streaming call to GenerateImagesStreaming
|
///*
|
||||||
|
/// Generates images using a loaded model, providing updates along the way.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - request: Request to send to GenerateImagesStreaming.
|
/// - request: Request to send to GenerateImagesStreaming.
|
||||||
@ -713,6 +714,8 @@ public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
|
|||||||
/// Generates images using a loaded model.
|
/// Generates images using a loaded model.
|
||||||
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
|
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Generates images using a loaded model, providing updates along the way.
|
||||||
func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext<SdGenerateImagesStreamUpdate>) -> EventLoopFuture<GRPCStatus>
|
func generateImagesStreaming(request: SdGenerateImagesRequest, context: StreamingResponseCallContext<SdGenerateImagesStreamUpdate>) -> EventLoopFuture<GRPCStatus>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -770,6 +773,8 @@ public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
|
|||||||
context: GRPCAsyncServerCallContext
|
context: GRPCAsyncServerCallContext
|
||||||
) async throws -> SdGenerateImagesResponse
|
) async throws -> SdGenerateImagesResponse
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Generates images using a loaded model, providing updates along the way.
|
||||||
@Sendable func generateImagesStreaming(
|
@Sendable func generateImagesStreaming(
|
||||||
request: SdGenerateImagesRequest,
|
request: SdGenerateImagesRequest,
|
||||||
responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,
|
responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,
|
||||||
|
@ -415,6 +415,11 @@ public struct SdGenerateImagesRequest {
|
|||||||
/// If not specified, a reasonable default value is used.
|
/// If not specified, a reasonable default value is used.
|
||||||
public var stepCount: UInt32 = 0
|
public var stepCount: UInt32 = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Indicates whether to send intermediate images
|
||||||
|
/// while in streaming mode.
|
||||||
|
public var sendIntermediates: Bool = false
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
@ -442,25 +447,42 @@ public struct SdGenerateImagesResponse {
|
|||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a progress update for an image generation batch.
|
||||||
public struct SdGenerateImagesBatchProgressUpdate {
|
public struct SdGenerateImagesBatchProgressUpdate {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The percentage of this batch that is complete.
|
||||||
public var percentageComplete: Float = 0
|
public var percentageComplete: Float = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The current state of the generated images from this batch.
|
||||||
|
/// These are not usually completed images, but partial images.
|
||||||
|
/// These are only available if the request's send_intermediates
|
||||||
|
/// parameter is set to true.
|
||||||
|
public var images: [SdImage] = []
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a completion of an image generation batch.
|
||||||
public struct SdGenerateImagesBatchCompletedUpdate {
|
public struct SdGenerateImagesBatchCompletedUpdate {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The generated images from this batch.
|
||||||
public var images: [SdImage] = []
|
public var images: [SdImage] = []
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The seed for this batch.
|
||||||
public var seed: UInt32 = 0
|
public var seed: UInt32 = 0
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
@ -475,10 +497,16 @@ public struct SdGenerateImagesStreamUpdate {
|
|||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The current batch number that is processing.
|
||||||
public var currentBatch: UInt32 = 0
|
public var currentBatch: UInt32 = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// An update to the image generation pipeline.
|
||||||
public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil
|
public var update: SdGenerateImagesStreamUpdate.OneOf_Update? = nil
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Batch progress update.
|
||||||
public var batchProgress: SdGenerateImagesBatchProgressUpdate {
|
public var batchProgress: SdGenerateImagesBatchProgressUpdate {
|
||||||
get {
|
get {
|
||||||
if case .batchProgress(let v)? = update {return v}
|
if case .batchProgress(let v)? = update {return v}
|
||||||
@ -487,6 +515,8 @@ public struct SdGenerateImagesStreamUpdate {
|
|||||||
set {update = .batchProgress(newValue)}
|
set {update = .batchProgress(newValue)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Batch completion update.
|
||||||
public var batchCompleted: SdGenerateImagesBatchCompletedUpdate {
|
public var batchCompleted: SdGenerateImagesBatchCompletedUpdate {
|
||||||
get {
|
get {
|
||||||
if case .batchCompleted(let v)? = update {return v}
|
if case .batchCompleted(let v)? = update {return v}
|
||||||
@ -499,8 +529,14 @@ public struct SdGenerateImagesStreamUpdate {
|
|||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// An update to the image generation pipeline.
|
||||||
public enum OneOf_Update: Equatable {
|
public enum OneOf_Update: Equatable {
|
||||||
|
///*
|
||||||
|
/// Batch progress update.
|
||||||
case batchProgress(SdGenerateImagesBatchProgressUpdate)
|
case batchProgress(SdGenerateImagesBatchProgressUpdate)
|
||||||
|
///*
|
||||||
|
/// Batch completion update.
|
||||||
case batchCompleted(SdGenerateImagesBatchCompletedUpdate)
|
case batchCompleted(SdGenerateImagesBatchCompletedUpdate)
|
||||||
|
|
||||||
#if !swift(>=4.1)
|
#if !swift(>=4.1)
|
||||||
@ -796,6 +832,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
11: .standard(proto: "guidance_scale"),
|
11: .standard(proto: "guidance_scale"),
|
||||||
12: .same(proto: "strength"),
|
12: .same(proto: "strength"),
|
||||||
13: .standard(proto: "step_count"),
|
13: .standard(proto: "step_count"),
|
||||||
|
14: .standard(proto: "send_intermediates"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -817,6 +854,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }()
|
case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }()
|
||||||
case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }()
|
case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }()
|
||||||
case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }()
|
case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }()
|
||||||
|
case 14: try { try decoder.decodeSingularBoolField(value: &self.sendIntermediates) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -866,6 +904,9 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
if self.stepCount != 0 {
|
if self.stepCount != 0 {
|
||||||
try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13)
|
try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13)
|
||||||
}
|
}
|
||||||
|
if self.sendIntermediates != false {
|
||||||
|
try visitor.visitSingularBoolField(value: self.sendIntermediates, fieldNumber: 14)
|
||||||
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -883,6 +924,7 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
if lhs.guidanceScale != rhs.guidanceScale {return false}
|
if lhs.guidanceScale != rhs.guidanceScale {return false}
|
||||||
if lhs.strength != rhs.strength {return false}
|
if lhs.strength != rhs.strength {return false}
|
||||||
if lhs.stepCount != rhs.stepCount {return false}
|
if lhs.stepCount != rhs.stepCount {return false}
|
||||||
|
if lhs.sendIntermediates != rhs.sendIntermediates {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -930,6 +972,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
|
|||||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate"
|
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesBatchProgressUpdate"
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .standard(proto: "percentage_complete"),
|
1: .standard(proto: "percentage_complete"),
|
||||||
|
2: .same(proto: "images"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -939,6 +982,7 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
|
|||||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }()
|
case 1: try { try decoder.decodeSingularFloatField(value: &self.percentageComplete) }()
|
||||||
|
case 2: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -948,11 +992,15 @@ extension SdGenerateImagesBatchProgressUpdate: SwiftProtobuf.Message, SwiftProto
|
|||||||
if self.percentageComplete != 0 {
|
if self.percentageComplete != 0 {
|
||||||
try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1)
|
try visitor.visitSingularFloatField(value: self.percentageComplete, fieldNumber: 1)
|
||||||
}
|
}
|
||||||
|
if !self.images.isEmpty {
|
||||||
|
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 2)
|
||||||
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool {
|
public static func ==(lhs: SdGenerateImagesBatchProgressUpdate, rhs: SdGenerateImagesBatchProgressUpdate) -> Bool {
|
||||||
if lhs.percentageComplete != rhs.percentageComplete {return false}
|
if lhs.percentageComplete != rhs.percentageComplete {return false}
|
||||||
|
if lhs.images != rhs.images {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user