mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-04 14:01:32 +00:00
Document API, make the implementation match the API, and update the same.
This commit is contained in:
@ -2,6 +2,6 @@ import Foundation
|
||||
|
||||
public enum SdCoreError: Error {
|
||||
case modelNotLoaded
|
||||
case imageEncode
|
||||
case imageEncodeFailed
|
||||
case modelNotFound
|
||||
}
|
||||
|
@ -1,22 +1,38 @@
|
||||
import CoreImage
|
||||
import Foundation
|
||||
import StableDiffusionProtos
|
||||
import UniformTypeIdentifiers
|
||||
|
||||
extension CGImage {
|
||||
func toPngData() throws -> Data {
|
||||
func toImageData(format: SdImageFormat) throws -> Data {
|
||||
guard let data = CFDataCreateMutable(nil, 0) else {
|
||||
throw SdCoreError.imageEncode
|
||||
throw SdCoreError.imageEncodeFailed
|
||||
}
|
||||
|
||||
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
|
||||
throw SdCoreError.imageEncode
|
||||
guard let destination = try CGImageDestinationCreateWithData(data, formatToTypeIdentifier(format) as CFString, 1, nil) else {
|
||||
throw SdCoreError.imageEncodeFailed
|
||||
}
|
||||
|
||||
CGImageDestinationAddImage(destination, self, nil)
|
||||
if CGImageDestinationFinalize(destination) {
|
||||
return data as Data
|
||||
} else {
|
||||
throw SdCoreError.imageEncode
|
||||
throw SdCoreError.imageEncodeFailed
|
||||
}
|
||||
}
|
||||
|
||||
/// Encodes this image in the requested format and wraps the bytes in an `SdImage`.
///
/// - Parameter format: Encoding to use; it is also recorded on the returned message.
/// - Returns: An `SdImage` whose `format` is `format` and whose `data` holds the encoded bytes.
/// - Throws: Any error thrown by `toImageData(format:)`.
func toSdImage(format: SdImageFormat) throws -> SdImage {
    // Encode first so a failure exits before the message is built.
    let encodedBytes = try toImageData(format: format)

    var message = SdImage()
    message.data = encodedBytes
    message.format = format
    return message
}
|
||||
|
||||
/// Maps an `SdImageFormat` to the type identifier string handed to the image destination.
///
/// - Parameter format: The requested output format.
/// - Returns: `"public.png"` for `.png`.
/// - Throws: `SdCoreError.imageEncodeFailed` for every format other than `.png`.
private func formatToTypeIdentifier(_ format: SdImageFormat) throws -> String {
    guard case .png = format else {
        throw SdCoreError.imageEncodeFailed
    }
    return "public.png"
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ public actor ModelManager {
|
||||
self.modelBaseURL = modelBaseURL
|
||||
}
|
||||
|
||||
public func reloadModels() throws {
|
||||
public func reloadAvailableModels() throws {
|
||||
modelInfos.removeAll()
|
||||
modelUrls.removeAll()
|
||||
modelStates.removeAll()
|
||||
@ -26,8 +26,37 @@ public actor ModelManager {
|
||||
}
|
||||
}
|
||||
|
||||
public func listModels() -> [SdModelInfo] {
|
||||
Array(modelInfos.values)
|
||||
/// Returns one `SdModelInfo` per known model, annotated with its current load state.
///
/// For each entry in `modelInfos` a copy is made and filled in with:
/// - `isLoaded` and `loadedComputeUnits`, read from the matching `modelStates` entry when
///   one exists, otherwise set to `false` and an empty value;
/// - `supportedComputeUnits`, chosen from the model's `attention` kind.
///
/// - Returns: The annotated infos, in the iteration order of `modelInfos.values`.
/// - Throws: Nothing at present; the function stays `throws` so existing
///   `try await` call sites keep compiling.
public func listAvailableModels() async throws -> [SdModelInfo] {
    var results: [SdModelInfo] = []
    results.reserveCapacity(modelInfos.count)

    for storedInfo in modelInfos.values {
        // A plain assignment copies the message, so the stored entry is left untouched
        // without a JSON encode/parse round-trip.
        // NOTE(review): assumes SdModelInfo is a SwiftProtobuf-generated struct (value type) — confirm.
        var info = storedInfo

        if let state = modelStates[info.name] {
            info.isLoaded = await state.isModelLoaded()
            // When the state reports no compute units, the copied value is kept as-is.
            if let loadedComputeUnits = await state.loadedModelComputeUnits() {
                info.loadedComputeUnits = loadedComputeUnits
            }
        } else {
            info.isLoaded = false
            info.loadedComputeUnits = .init()
        }

        // Split-einsum models additionally advertise the Neural Engine and `.all` options.
        if info.attention == .splitEinSum {
            info.supportedComputeUnits = [.cpuAndGpu, .cpuAndNeuralEngine, .cpu, .all]
        } else {
            info.supportedComputeUnits = [.cpuAndGpu, .cpu]
        }

        results.append(info)
    }

    return results
}
|
||||
|
||||
public func createModelState(name: String) throws -> ModelState {
|
||||
@ -53,13 +82,14 @@ public actor ModelManager {
|
||||
private func addModel(url: URL) throws {
|
||||
var info = SdModelInfo()
|
||||
info.name = url.lastPathComponent
|
||||
let attention = getModelAttention(url)
|
||||
info.attention = attention ?? "unknown"
|
||||
if let attention = getModelAttention(url) {
|
||||
info.attention = attention
|
||||
}
|
||||
modelInfos[info.name] = info
|
||||
modelUrls[info.name] = url
|
||||
}
|
||||
|
||||
private func getModelAttention(_ url: URL) -> String? {
|
||||
private func getModelAttention(_ url: URL) -> SdModelAttention? {
|
||||
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
|
||||
|
||||
struct ModelMetadata: Decodable {
|
||||
@ -74,7 +104,7 @@ public actor ModelManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original"
|
||||
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? SdModelAttention.splitEinSum : SdModelAttention.original
|
||||
} catch {
|
||||
return nil
|
||||
}
|
||||
|
@ -7,14 +7,15 @@ public actor ModelState {
|
||||
private let url: URL
|
||||
private var pipeline: StableDiffusionPipeline?
|
||||
private var tokenizer: BPETokenizer?
|
||||
private var loadedConfiguration: MLModelConfiguration?
|
||||
|
||||
public init(url: URL) {
|
||||
self.url = url
|
||||
}
|
||||
|
||||
public func load() throws {
|
||||
public func load(request: SdLoadModelRequest) throws {
|
||||
let config = MLModelConfiguration()
|
||||
config.computeUnits = .cpuAndGPU
|
||||
config.computeUnits = request.computeUnits.toMlComputeUnits()
|
||||
pipeline = try StableDiffusionPipeline(
|
||||
resourcesAt: url,
|
||||
controlNet: [],
|
||||
@ -26,6 +27,15 @@ public actor ModelState {
|
||||
let vocabUrl = url.appending(component: "vocab.json")
|
||||
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
||||
try pipeline?.loadResources()
|
||||
loadedConfiguration = config
|
||||
}
|
||||
|
||||
/// Reports whether this state currently holds a pipeline.
///
/// - Returns: `true` when `pipeline` is non-nil, `false` otherwise.
public func isModelLoaded() -> Bool {
    guard pipeline != nil else { return false }
    return true
}
|
||||
|
||||
/// Returns the compute units recorded in `loadedConfiguration`.
///
/// - Returns: The stored configuration's compute units converted with `toSdComputeUnits()`,
///   or `nil` when no configuration has been stored.
public func loadedModelComputeUnits() -> SdComputeUnits? {
    guard let configuration = loadedConfiguration else { return nil }
    return configuration.computeUnits.toSdComputeUnits()
}
|
||||
|
||||
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
||||
@ -33,20 +43,25 @@ public actor ModelState {
|
||||
throw SdCoreError.modelNotLoaded
|
||||
}
|
||||
|
||||
let baseSeed: UInt32 = request.seed
|
||||
|
||||
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
||||
pipelineConfig.negativePrompt = request.negativePrompt
|
||||
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
|
||||
pipelineConfig.imageCount = Int(request.batchSize)
|
||||
var response = SdGenerateImagesResponse()
|
||||
for _ in 0 ..< request.imageCount {
|
||||
for _ in 0 ..< request.batchCount {
|
||||
var seed = baseSeed
|
||||
if seed == 0 {
|
||||
seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
}
|
||||
pipelineConfig.seed = seed
|
||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||
|
||||
for cgImage in images {
|
||||
guard let cgImage else { continue }
|
||||
var image = SdImage()
|
||||
image.content = try cgImage.toPngData()
|
||||
response.images.append(image)
|
||||
try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
|
||||
}
|
||||
response.seeds.append(seed)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
Reference in New Issue
Block a user