Files
stable-diffusion-rpc/Sources/StableDiffusionCore/ModelState.swift

94 lines
3.0 KiB
Swift
Raw Normal View History

import CoreML
2023-04-22 14:52:27 -07:00
import Foundation
import StableDiffusion
import StableDiffusionProtos
2023-04-22 14:52:27 -07:00
/// Owns a Stable Diffusion pipeline loaded from a model directory and serializes access to it.
public actor ModelState {
    /// Directory holding the pipeline resources, `merges.txt`, and `vocab.json`.
    private let url: URL
    /// The loaded pipeline; `nil` until a `load(request:)` call fully succeeds.
    private var pipeline: StableDiffusionPipeline?
    // NOTE(review): the tokenizer is loaded and stored but never read in this file;
    // constructing it does at least validate that merges.txt / vocab.json are readable.
    private var tokenizer: BPETokenizer?
    /// Configuration used for the currently loaded pipeline, if any.
    private var loadedConfiguration: MLModelConfiguration?

    /// Creates a model state for the model directory at `url`. Nothing is loaded yet.
    /// - Parameter url: Directory containing the Core ML pipeline resources.
    public init(url: URL) {
        self.url = url
    }

    /// Loads (or reloads) the pipeline using the compute units requested by the caller.
    ///
    /// Any previously loaded model is released first, so a failed load leaves the
    /// actor in a clean "not loaded" state rather than holding a half-initialized
    /// pipeline alongside a stale configuration.
    /// - Parameter request: Supplies the compute units used for the Core ML models.
    /// - Throws: Errors from pipeline construction, tokenizer loading, or `loadResources()`.
    public func load(request: SdLoadModelRequest) throws {
        // Release the old model before building the new one so two full models are
        // never resident at the same time, and so failure cannot leave mixed state.
        pipeline = nil
        tokenizer = nil
        loadedConfiguration = nil

        let config = MLModelConfiguration()
        config.computeUnits = request.computeUnits.toMlComputeUnits()

        let newPipeline = try StableDiffusionPipeline(
            resourcesAt: url,
            controlNet: [],
            configuration: config,
            disableSafety: true,
            reduceMemory: false
        )
        let mergesUrl = url.appending(component: "merges.txt")
        let vocabUrl = url.appending(component: "vocab.json")
        let newTokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
        try newPipeline.loadResources()

        // Commit only once every step above has succeeded.
        pipeline = newPipeline
        tokenizer = newTokenizer
        loadedConfiguration = config
    }

    /// Returns `true` when a pipeline has been successfully loaded.
    public func isModelLoaded() -> Bool {
        pipeline != nil
    }

    /// Returns the compute units of the currently loaded model, or `nil` if none is loaded.
    public func loadedModelComputeUnits() -> SdComputeUnits? {
        loadedConfiguration?.computeUnits.toSdComputeUnits()
    }

    /// Runs the pipeline `request.batchCount` times and collects every produced image.
    ///
    /// Numeric request fields left at 0 (batch size, guidance scale, step count,
    /// strength) keep the pipeline's defaults. A seed of 0 means "pick a random seed
    /// for each batch"; the seed actually used for every batch is reported in
    /// `response.seeds`.
    /// - Parameter request: Prompt, sampling options, and output format.
    /// - Returns: The generated images plus one seed per batch.
    /// - Throws: `SdCoreError.modelNotLoaded` if `load(request:)` has not succeeded,
    ///   `CancellationError` if the calling task is cancelled between batches, and
    ///   errors from image conversion or generation.
    public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
        guard let pipeline else {
            throw SdCoreError.modelNotLoaded
        }
        let baseSeed: UInt32 = request.seed

        var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
        pipelineConfig.negativePrompt = request.negativePrompt
        // 0 is the proto default for an unset field; asking the pipeline for zero
        // images is never useful, so keep its default image count instead.
        if request.batchSize != 0 {
            pipelineConfig.imageCount = Int(request.batchSize)
        }
        if request.hasStartingImage {
            pipelineConfig.startingImage = try request.startingImage.toCgImage()
        }
        if request.guidanceScale != 0.0 {
            pipelineConfig.guidanceScale = request.guidanceScale
        }
        if request.stepCount != 0 {
            pipelineConfig.stepCount = Int(request.stepCount)
        }
        if request.strength != 0.0 {
            pipelineConfig.strength = request.strength
        }
        pipelineConfig.disableSafety = !request.enableSafetyCheck
        switch request.scheduler {
        case .pndm: pipelineConfig.schedulerType = .pndmScheduler
        case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
        default: pipelineConfig.schedulerType = .pndmScheduler
        }

        var response = SdGenerateImagesResponse()
        for _ in 0 ..< request.batchCount {
            // Each batch is expensive; stop early if the caller has gone away.
            try Task.checkCancellation()

            // Closed range so every UInt32 value, including .max, can be drawn.
            let seed = baseSeed != 0 ? baseSeed : UInt32.random(in: 0 ... UInt32.max)
            pipelineConfig.seed = seed

            let images = try pipeline.generateImages(configuration: pipelineConfig)
            // nil entries returned by the pipeline are skipped.
            for case let cgImage? in images {
                try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
            }
            response.seeds.append(seed)
        }
        return response
    }
}