Job management and preparation for multi-hosting.

This commit is contained in:
2023-05-08 16:06:07 -07:00
parent a2d9e14f3a
commit ace2c07aa1
30 changed files with 3879 additions and 2307 deletions

View File

@ -5,4 +5,6 @@ public enum SdCoreError: Error {
case imageEncodeFailed
case imageDecodeFailed
case modelNotFound
case jobNotFound
case notImplemented
}

View File

@ -0,0 +1,87 @@
import Combine
import Foundation
import StableDiffusionProtos
public typealias JobUpdateSubject = PassthroughSubject<SdJob, Never>
public actor JobManager {
public let jobUpdateSubject = JobUpdateSubject()
public let jobUpdatePublisher: AsyncPublisher<JobUpdateSubject>
private var jobs: [UInt64: SdJob] = [:]
private var id: UInt64 = 0
public init() {
jobUpdatePublisher = AsyncPublisher(jobUpdateSubject)
}
func nextId() -> UInt64 {
id += 1
return id
}
public func create() -> SdJob {
var job = SdJob()
job.id = nextId()
job.state = .queued
jobs[job.id] = job
return job
}
public func job(id: UInt64) -> SdJob? {
guard let job = jobs[id] else {
return nil
}
return try? SdJob(serializedData: job.serializedData())
}
public func updateJobQueued(_ job: SdJob) {
guard var stored = jobs[job.id] else {
return
}
stored.state = .queued
jobUpdateSubject.send(stored)
jobs[job.id] = stored
}
public func updateJobProgress(_ job: SdJob, progress: Float) {
guard var stored = jobs[job.id] else {
return
}
stored.state = .running
stored.overallPercentageComplete = progress
jobUpdateSubject.send(stored)
jobs[job.id] = stored
}
public func updateJobCompleted(_ job: SdJob) {
guard var stored = jobs[job.id] else {
return
}
stored.state = .completed
stored.overallPercentageComplete = 100.0
jobUpdateSubject.send(stored)
jobs[job.id] = stored
}
public func updateJobRunning(_ job: SdJob) {
guard var stored = jobs[job.id] else {
return
}
stored.state = .running
stored.overallPercentageComplete = 0.0
jobUpdateSubject.send(stored)
jobs[job.id] = stored
}
public func listAllJobs() -> [SdJob] {
var copy: [SdJob] = []
for item in jobs.values {
guard let job = try? SdJob(serializedData: item.serializedData()) else {
continue
}
copy.append(job)
}
return copy
}
}

View File

@ -8,9 +8,11 @@ public actor ModelManager {
private var modelStates: [String: ModelState] = [:]
private let modelBaseURL: URL
private let jobManager: JobManager
public init(modelBaseURL: URL) {
public init(modelBaseURL: URL, jobManager: JobManager) {
self.modelBaseURL = modelBaseURL
self.jobManager = jobManager
}
public func reloadAvailableModels() throws {
@ -67,7 +69,7 @@ public actor ModelManager {
}
if state == nil {
let state = ModelState(url: url)
let state = ModelState(url: url, jobManager: jobManager)
modelStates[name] = state
return state
} else {

View File

@ -5,13 +5,16 @@ import StableDiffusion
import StableDiffusionProtos
public actor ModelState {
private let jobManager: JobManager
private let url: URL
private var pipeline: StableDiffusionPipeline?
private var tokenizer: BPETokenizer?
private var loadedConfiguration: MLModelConfiguration?
public init(url: URL) {
public init(url: URL, jobManager: JobManager) {
self.url = url
self.jobManager = jobManager
}
public func load(request: SdLoadModelRequest) throws {
@ -39,36 +42,21 @@ public actor ModelState {
loadedConfiguration?.computeUnits.toSdComputeUnits()
}
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
public func generate(_ request: SdGenerateImagesRequest, job: SdJob, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>? = nil) async throws -> SdGenerateImagesResponse {
guard let pipeline else {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
try response.images.append(contentsOf: cgImagesToImages(request: request, images))
response.seeds.append(seed)
}
return response
}
public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
guard let pipeline else {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = try toPipelineConfig(request)
DispatchQueue.main.async {
Task {
await self.jobManager.updateJobRunning(job)
}
}
for batch in 1 ... request.batchCount {
@Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
let eachSegment = 100.0 / Float(request.batchCount)
@ -89,30 +77,51 @@ public actor ModelState {
images = try? cgImagesToImages(request: request, progress.currentImages)
}
let finalImages = images
Task {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
if let finalImages {
update.images = finalImages
let overallPercentage = currentOverallPercentage(percentage)
DispatchQueue.main.async {
Task {
await self.jobManager.updateJobProgress(job, progress: overallPercentage)
}
}
if let stream {
Task {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchProgress = .with { update in
update.percentageComplete = percentage
if let finalImages {
update.images = finalImages
}
}
}
item.overallPercentageComplete = currentOverallPercentage(percentage)
})
item.overallPercentageComplete = overallPercentage
item.jobID = job.id
})
}
}
return true
})
let images = try cgImagesToImages(request: request, cgImages)
try await stream.send(.with { item in
item.currentBatch = batch
item.batchCompleted = .with { update in
update.images = images
update.seed = seed
DispatchQueue.main.async {
Task {
await self.jobManager.updateJobCompleted(job)
}
item.overallPercentageComplete = currentOverallPercentage(100.0)
})
}
if let stream {
try await stream.send(.with { item in
item.currentBatch = batch
item.batchCompleted = .with { update in
update.images = images
update.seed = seed
}
item.overallPercentageComplete = currentOverallPercentage(100.0)
item.jobID = job.id
})
} else {
response.images.append(contentsOf: images)
response.seeds.append(seed)
}
}
return response
}
private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {