mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 13:31:32 +00:00
154 lines
5.7 KiB
Swift
154 lines
5.7 KiB
Swift
import CoreML
|
|
import Foundation
|
|
import GRPC
|
|
import StableDiffusion
|
|
import StableDiffusionProtos
|
|
|
|
/// Owns one Stable Diffusion model directory and the pipeline loaded from it.
///
/// All mutable state is actor-isolated, so requests touching it are serialized
/// through this actor (note that `generateStreaming` suspends while sending, so
/// other calls may interleave between its batches).
public actor ModelState {
  /// Directory holding the model resources plus `merges.txt` and `vocab.json`.
  private let url: URL
  /// The loaded pipeline; `nil` until `load(request:)` completes successfully.
  private var pipeline: StableDiffusionPipeline?
  /// Tokenizer built from the model's merges/vocabulary files during `load(request:)`.
  /// NOTE(review): not read anywhere in this actor; constructing it only makes
  /// `load` fail when those files cannot be read.
  private var tokenizer: BPETokenizer?
  /// Configuration used by the most recent successful `load(request:)`.
  private var loadedConfiguration: MLModelConfiguration?

  /// Creates state for the model at `url`. Nothing is loaded until `load(request:)`.
  public init(url: URL) {
    self.url = url
  }

  /// Loads the pipeline and tokenizer from `url` using the requested compute units.
  ///
  /// Any previously loaded model is dropped first, so its memory is released before
  /// the new model loads.
  /// - Parameter request: Supplies the compute units to run the model on.
  /// - Throws: Errors from creating the pipeline or tokenizer, or from loading the
  ///   pipeline's resources. On failure the actor is left in the "not loaded" state.
  public func load(request: SdLoadModelRequest) throws {
    // Reset first so a failure part-way through cannot leave `pipeline` set
    // (making `isModelLoaded()` true) with unloaded resources or a stale
    // `loadedConfiguration`.
    pipeline = nil
    tokenizer = nil
    loadedConfiguration = nil

    let config = MLModelConfiguration()
    config.computeUnits = request.computeUnits.toMlComputeUnits()

    let newPipeline = try StableDiffusionPipeline(
      resourcesAt: url,
      controlNet: [],
      configuration: config,
      disableSafety: true,
      reduceMemory: false
    )
    let mergesUrl = url.appending(component: "merges.txt")
    let vocabUrl = url.appending(component: "vocab.json")
    let newTokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
    try newPipeline.loadResources()

    // Commit only after every step has succeeded.
    pipeline = newPipeline
    tokenizer = newTokenizer
    loadedConfiguration = config
  }

  /// Whether a pipeline has been successfully loaded.
  public func isModelLoaded() -> Bool {
    pipeline != nil
  }

  /// Compute units of the currently loaded model, or `nil` if none is loaded.
  public func loadedModelComputeUnits() -> SdComputeUnits? {
    loadedConfiguration?.computeUnits.toSdComputeUnits()
  }

  /// Generates `request.batchCount` batches of images and returns them all at once.
  ///
  /// - Parameter request: Prompt and generation settings. A `seed` of 0 means a
  ///   random seed is chosen independently for each batch.
  /// - Returns: The images from every batch, plus the seed used for each batch.
  /// - Throws: `SdCoreError.modelNotLoaded` if no model is loaded, or any error from
  ///   building the configuration, running the pipeline, or encoding images.
  public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
    guard let pipeline else {
      throw SdCoreError.modelNotLoaded
    }

    var pipelineConfig = try toPipelineConfig(request)
    var response = SdGenerateImagesResponse()
    for _ in 0 ..< request.batchCount {
      let seed = resolveSeed(request.seed)
      pipelineConfig.seed = seed
      let images = try pipeline.generateImages(configuration: pipelineConfig)
      try response.images.append(contentsOf: cgImagesToImages(request: request, images))
      response.seeds.append(seed)
    }
    return response
  }

  /// Generates `request.batchCount` batches, streaming progress and each finished batch.
  ///
  /// Generation stops early if the calling task is cancelled (for example when the
  /// RPC is cancelled).
  /// - Parameters:
  ///   - request: Prompt and generation settings. A `seed` of 0 means a random seed
  ///     is chosen independently for each batch.
  ///   - stream: Receives per-step progress updates and one `batchCompleted`
  ///     update per batch.
  /// - Throws: `SdCoreError.modelNotLoaded` if no model is loaded,
  ///   `CancellationError` if the task was cancelled, or any error from building the
  ///   configuration, running the pipeline, encoding images, or sending a completed
  ///   batch.
  public func generateStreaming(_ request: SdGenerateImagesRequest, stream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>) async throws {
    guard let pipeline else {
      throw SdCoreError.modelNotLoaded
    }

    // `1 ... 0` traps at runtime, so the proto default of 0 batches would crash
    // the server. Match `generate`, which produces nothing for zero batches.
    let batchCount = request.batchCount
    guard batchCount > 0 else {
      return
    }

    var pipelineConfig = try toPipelineConfig(request)

    for batch in 1 ... batchCount {
      /// Maps progress within the current batch (0-100) onto progress across all batches.
      @Sendable func currentOverallPercentage(_ batchPercentage: Float) -> Float {
        let eachSegment = 100.0 / Float(batchCount)
        let alreadyCompletedSegments = (Float(batch) - 1) * eachSegment
        let percentageToAdd = eachSegment * (batchPercentage / 100.0)
        return alreadyCompletedSegments + percentageToAdd
      }

      let seed = resolveSeed(request.seed)
      pipelineConfig.seed = seed

      let cgImages = try pipeline.generateImages(configuration: pipelineConfig, progressHandler: { progress in
        let percentage = (Float(progress.step) / Float(progress.stepCount)) * 100.0
        // NOTE(review): each update is sent from its own unstructured Task, so
        // updates may arrive out of order or after the batch-completed message.
        Task {
          // Progress is best-effort: a failed send (e.g. the client went away)
          // must not crash the whole server. The batch-completed send below
          // uses `try` and surfaces stream failures to the caller.
          try? await stream.send(.with { item in
            item.currentBatch = batch
            item.batchProgress = .with { update in
              update.percentageComplete = percentage
            }
            item.overallPercentageComplete = currentOverallPercentage(percentage)
          })
        }
        // Returning false asks the pipeline to stop; do so once we are cancelled.
        return !Task.isCancelled
      })
      // If generation was stopped because of cancellation, report that instead of
      // sending a partial batch.
      try Task.checkCancellation()

      let images = try cgImagesToImages(request: request, cgImages)
      try await stream.send(.with { item in
        item.currentBatch = batch
        item.batchCompleted = .with { update in
          update.images = images
          update.seed = seed
        }
        item.overallPercentageComplete = currentOverallPercentage(100.0)
      })
    }
  }

  /// Returns `requested`, or a freshly drawn random seed when `requested` is 0.
  private func resolveSeed(_ requested: UInt32) -> UInt32 {
    requested != 0 ? requested : UInt32.random(in: 0 ..< UInt32.max)
  }

  /// Encodes each non-nil image in `request.outputImageFormat`, skipping `nil` entries.
  /// - Throws: Any error from `toSdImage(format:)`.
  private func cgImagesToImages(request: SdGenerateImagesRequest, _ cgImages: [CGImage?]) throws -> [SdImage] {
    try cgImages.compactMap { cgImage in
      try cgImage?.toSdImage(format: request.outputImageFormat)
    }
  }

  /// Translates a generate request into a pipeline configuration.
  ///
  /// Zero-valued `guidanceScale`, `stepCount`, and `strength` are treated as "unset"
  /// and leave the pipeline's own defaults in place. Unknown schedulers fall back to
  /// PNDM.
  /// - Throws: Any error from decoding `request.startingImage`.
  private func toPipelineConfig(_ request: SdGenerateImagesRequest) throws -> StableDiffusionPipeline.Configuration {
    var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
    pipelineConfig.negativePrompt = request.negativePrompt
    pipelineConfig.imageCount = Int(request.batchSize)

    if request.hasStartingImage {
      pipelineConfig.startingImage = try request.startingImage.toCgImage()
    }

    if request.guidanceScale != 0.0 {
      pipelineConfig.guidanceScale = request.guidanceScale
    }

    if request.stepCount != 0 {
      pipelineConfig.stepCount = Int(request.stepCount)
    }

    if request.strength != 0.0 {
      pipelineConfig.strength = request.strength
    }

    pipelineConfig.disableSafety = !request.enableSafetyCheck

    switch request.scheduler {
    case .pndm: pipelineConfig.schedulerType = .pndmScheduler
    case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
    default: pipelineConfig.schedulerType = .pndmScheduler
    }
    return pipelineConfig
  }
}
|