Start work on C++ client, and implement streaming of image generation.

This commit is contained in:
2023-04-23 14:22:10 -07:00
parent 1bb629c18f
commit b063d91b1e
11 changed files with 509 additions and 31 deletions

View File

@ -1,5 +1,6 @@
import CoreML
import Foundation
import GRPC
import StableDiffusion
import StableDiffusionProtos
@ -44,7 +45,82 @@ public actor ModelState {
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
try response.images.append(contentsOf: cgImagesToImages(request: request, images))
response.seeds.append(seed)
}
return response
}
public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
guard let pipeline else {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
for batch in 1 ... request.batchCount {
@Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
let eachSegment = 100.0 / Float(request.batchCount)
let alreadyCompletedSegments = (Float(batch) - 1) * eachSegment
let percentageToAdd = eachSegment * (batchPercentage / 100.0)
return alreadyCompletedSegments + percentageToAdd
}
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
Task {
do {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
}
item.overallPercentageComplete = currentOverallPercentage(percentage)
})
} catch {
fatalError(error.localizedDescription)
}
}
return true
})
let images = try cgImagesToImages(request: request, cgImages)
try await stream.send(.with { item in
item.currentBatch = batch
item.batchCompleted = .with { update in
update.images = images
update.seed = seed
}
item.overallPercentageComplete = currentOverallPercentage(100.0)
})
}
}
private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {
var images: [SdImage] = []
for cgImage in cgImages {
guard let cgImage else { continue }
try images.append(cgImage.toSdImage(format: request.outputImageFormat))
}
return images
}
private func toPipelineConfig(_ request: SdGenerateImagesRequest) throws -> StableDiffusionPipeline.Configuration {
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
pipelineConfig.negativePrompt = request.negativePrompt
pipelineConfig.imageCount = Int(request.batchSize)
@ -72,22 +148,6 @@ public actor ModelState {
case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
default: pipelineConfig.schedulerType = .pndmScheduler
}
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
for cgImage in images {
guard let cgImage else { continue }
try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
}
response.seeds.append(seed)
}
return response
return pipelineConfig
}
}