Job management and preparation for multi-hosting.

This commit is contained in:
2023-05-08 16:06:07 -07:00
parent a2d9e14f3a
commit ace2c07aa1
30 changed files with 3879 additions and 2307 deletions

View File

@ -5,13 +5,16 @@ import StableDiffusion
import StableDiffusionProtos
public actor ModelState {
private let jobManager: JobManager
private let url: URL
private var pipeline: StableDiffusionPipeline?
private var tokenizer: BPETokenizer?
private var loadedConfiguration: MLModelConfiguration?
public init(url: URL) {
public init(url: URL, jobManager: JobManager) {
self.url = url
self.jobManager = jobManager
}
public func load(request: SdLoadModelRequest) throws {
@ -39,36 +42,21 @@ public actor ModelState {
loadedConfiguration?.computeUnits.toSdComputeUnits()
}
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
public func generate(_ request: SdGenerateImagesRequest, job: SdJob, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>? = nil) async throws -> SdGenerateImagesResponse {
guard let pipeline else {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
try response.images.append(contentsOf: cgImagesToImages(request: request, images))
response.seeds.append(seed)
}
return response
}
public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
guard let pipeline else {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
DispatchQueue.main.async {
Task {
await self.jobManager.updateJobRunning(job)
}
}
for batch in 1 ... request.batchCount {
@Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
let eachSegment = 100.0 / Float(request.batchCount)
@ -89,30 +77,51 @@ public actor ModelState {
images = try? cgImagesToImages(request: request, progress.currentImages)
}
let finalImages = images
Task {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
if let finalImages {
update.images = finalImages
let overallPercentage = currentOverallPercentage(percentage)
DispatchQueue.main.async {
Task {
await self.jobManager.updateJobProgress(job, progress: overallPercentage)
}
}
if let stream {
Task {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
if let finalImages {
update.images = finalImages
}
}
}
item.overallPercentageComplete = currentOverallPercentage(percentage)
})
item.overallPercentageComplete = overallPercentage
item.jobID = job.id
})
}
}
return true
})
let images = try cgImagesToImages(request: request, cgImages)
try await stream.send(.with { item in
item.currentBatch = batch
item.batchCompleted = .with { update in
update.images = images
update.seed = seed
DispatchQueue.main.async {
Task {
await self.jobManager.updateJobCompleted(job)
}
item.overallPercentageComplete = currentOverallPercentage(100.0)
})
}
if let stream {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchCompleted = .with { update in
update.images = images
update.seed = seed
}
item.overallPercentageComplete = currentOverallPercentage(100.0)
item.jobID = job.id
})
} else {
response.images.append(contentsOf: images)
response.seeds.append(seed)
}
}
return response
}
private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {