mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-05 22:41:30 +00:00
Job management and preparation for multi-hosting.
This commit is contained in:
@ -5,13 +5,16 @@ import StableDiffusion
|
||||
import StableDiffusionProtos
|
||||
|
||||
public actor ModelState {
|
||||
private let jobManager: JobManager
|
||||
|
||||
private let url: URL
|
||||
private var pipeline: StableDiffusionPipeline?
|
||||
private var tokenizer: BPETokenizer?
|
||||
private var loadedConfiguration: MLModelConfiguration?
|
||||
|
||||
public init(url: URL) {
|
||||
public init(url: URL, jobManager: JobManager) {
|
||||
self.url = url
|
||||
self.jobManager = jobManager
|
||||
}
|
||||
|
||||
public func load(request: SdLoadModelRequest) throws {
|
||||
@ -39,36 +42,21 @@ public actor ModelState {
|
||||
loadedConfiguration?.computeUnits.toSdComputeUnits()
|
||||
}
|
||||
|
||||
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
||||
public func generate(_ request: SdGenerateImagesRequest, job: SdJob, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>? = nil) async throws -> SdGenerateImagesResponse {
|
||||
guard let pipeline else {
|
||||
throw SdCoreError.modelNotLoaded
|
||||
}
|
||||
|
||||
let baseSeed: UInt32 = request.seed
|
||||
var pipelineConfig = try toPipelineConfig(request)
|
||||
|
||||
var response = SdGenerateImagesResponse()
|
||||
for _ in 0 ..< request.batchCount {
|
||||
var seed = baseSeed
|
||||
if seed == 0 {
|
||||
seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
}
|
||||
pipelineConfig.seed = seed
|
||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||
try response.images.append(contentsOf: cgImagesToImages(request: request, images))
|
||||
response.seeds.append(seed)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
|
||||
guard let pipeline else {
|
||||
throw SdCoreError.modelNotLoaded
|
||||
}
|
||||
|
||||
let baseSeed: UInt32 = request.seed
|
||||
var pipelineConfig = try toPipelineConfig(request)
|
||||
|
||||
DispatchQueue.main.async {
|
||||
Task {
|
||||
await self.jobManager.updateJobRunning(job)
|
||||
}
|
||||
}
|
||||
|
||||
for batch in 1 ... request.batchCount {
|
||||
@Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
|
||||
let eachSegment = 100.0 / Float(request.batchCount)
|
||||
@ -89,30 +77,51 @@ public actor ModelState {
|
||||
images = try? cgImagesToImages(request: request, progress.currentImages)
|
||||
}
|
||||
let finalImages = images
|
||||
Task {
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchProgress = .with { update in
|
||||
update.percentageComplete = percentage
|
||||
if let finalImages {
|
||||
update.images = finalImages
|
||||
let overallPercentage = currentOverallPercentage(percentage)
|
||||
DispatchQueue.main.async {
|
||||
Task {
|
||||
await self.jobManager.updateJobProgress(job, progress: overallPercentage)
|
||||
}
|
||||
}
|
||||
if let stream {
|
||||
Task {
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchProgress = .with { update in
|
||||
update.percentageComplete = percentage
|
||||
if let finalImages {
|
||||
update.images = finalImages
|
||||
}
|
||||
}
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(percentage)
|
||||
})
|
||||
item.overallPercentageComplete = overallPercentage
|
||||
item.jobID = job.id
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
let images = try cgImagesToImages(request: request, cgImages)
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchCompleted = .with { update in
|
||||
update.images = images
|
||||
update.seed = seed
|
||||
DispatchQueue.main.async {
|
||||
Task {
|
||||
await self.jobManager.updateJobCompleted(job)
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(100.0)
|
||||
})
|
||||
}
|
||||
if let stream {
|
||||
try await stream.send(.with { item in
|
||||
item.currentBatch = batch
|
||||
item.batchCompleted = .with { update in
|
||||
update.images = images
|
||||
update.seed = seed
|
||||
}
|
||||
item.overallPercentageComplete = currentOverallPercentage(100.0)
|
||||
item.jobID = job.id
|
||||
})
|
||||
} else {
|
||||
response.images.append(contentsOf: images)
|
||||
response.seeds.append(seed)
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {
|
||||
|
Reference in New Issue
Block a user