Document API, make the implementation match the API, and update the same.

This commit is contained in:
2023-04-23 02:09:50 -07:00
parent 71afed326f
commit 7c0b2779f4
15 changed files with 707 additions and 310 deletions

View File

@ -2,6 +2,6 @@ import Foundation
public enum SdCoreError: Error {
case modelNotLoaded
case imageEncode
case imageEncodeFailed
case modelNotFound
}

View File

@ -1,22 +1,38 @@
import CoreImage
import Foundation
import StableDiffusionProtos
import UniformTypeIdentifiers
extension CGImage {
func toPngData() throws -> Data {
func toImageData(format: SdImageFormat) throws -> Data {
guard let data = CFDataCreateMutable(nil, 0) else {
throw SdCoreError.imageEncode
throw SdCoreError.imageEncodeFailed
}
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
throw SdCoreError.imageEncode
guard let destination = try CGImageDestinationCreateWithData(data, formatToTypeIdentifier(format) as CFString, 1, nil) else {
throw SdCoreError.imageEncodeFailed
}
CGImageDestinationAddImage(destination, self, nil)
if CGImageDestinationFinalize(destination) {
return data as Data
} else {
throw SdCoreError.imageEncode
throw SdCoreError.imageEncodeFailed
}
}
func toSdImage(format: SdImageFormat) throws -> SdImage {
let content = try toImageData(format: format)
var image = SdImage()
image.format = format
image.data = content
return image
}
private func formatToTypeIdentifier(_ format: SdImageFormat) throws -> String {
switch format {
case .png: return "public.png"
default: throw SdCoreError.imageEncodeFailed
}
}
}

View File

@ -13,7 +13,7 @@ public actor ModelManager {
self.modelBaseURL = modelBaseURL
}
public func reloadModels() throws {
public func reloadAvailableModels() throws {
modelInfos.removeAll()
modelUrls.removeAll()
modelStates.removeAll()
@ -26,8 +26,37 @@ public actor ModelManager {
}
}
public func listModels() -> [SdModelInfo] {
Array(modelInfos.values)
public func listAvailableModels() async throws -> [SdModelInfo] {
var results: [SdModelInfo] = []
for simpleInfo in modelInfos.values {
var info = try SdModelInfo(jsonString: simpleInfo.jsonString())
if let maybeLoaded = modelStates[info.name] {
info.isLoaded = await maybeLoaded.isModelLoaded()
if let loadedComputeUnits = await maybeLoaded.loadedModelComputeUnits() {
info.loadedComputeUnits = loadedComputeUnits
}
} else {
info.isLoaded = false
info.loadedComputeUnits = .init()
}
if info.attention == .splitEinSum {
info.supportedComputeUnits = [
.cpuAndGpu,
.cpuAndNeuralEngine,
.cpu,
.all
]
} else {
info.supportedComputeUnits = [
.cpuAndGpu,
.cpu
]
}
results.append(info)
}
return results
}
public func createModelState(name: String) throws -> ModelState {
@ -53,13 +82,14 @@ public actor ModelManager {
private func addModel(url: URL) throws {
var info = SdModelInfo()
info.name = url.lastPathComponent
let attention = getModelAttention(url)
info.attention = attention ?? "unknown"
if let attention = getModelAttention(url) {
info.attention = attention
}
modelInfos[info.name] = info
modelUrls[info.name] = url
}
private func getModelAttention(_ url: URL) -> String? {
private func getModelAttention(_ url: URL) -> SdModelAttention? {
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
struct ModelMetadata: Decodable {
@ -74,7 +104,7 @@ public actor ModelManager {
return nil
}
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original"
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? SdModelAttention.splitEinSum : SdModelAttention.original
} catch {
return nil
}

View File

@ -7,14 +7,15 @@ public actor ModelState {
private let url: URL
private var pipeline: StableDiffusionPipeline?
private var tokenizer: BPETokenizer?
private var loadedConfiguration: MLModelConfiguration?
public init(url: URL) {
self.url = url
}
public func load() throws {
public func load(request: SdLoadModelRequest) throws {
let config = MLModelConfiguration()
config.computeUnits = .cpuAndGPU
config.computeUnits = request.computeUnits.toMlComputeUnits()
pipeline = try StableDiffusionPipeline(
resourcesAt: url,
controlNet: [],
@ -26,6 +27,15 @@ public actor ModelState {
let vocabUrl = url.appending(component: "vocab.json")
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
try pipeline?.loadResources()
loadedConfiguration = config
}
public func isModelLoaded() -> Bool {
pipeline != nil
}
public func loadedModelComputeUnits() -> SdComputeUnits? {
loadedConfiguration?.computeUnits.toSdComputeUnits()
}
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
@ -33,20 +43,25 @@ public actor ModelState {
throw SdCoreError.modelNotLoaded
}
let baseSeed: UInt32 = request.seed
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
pipelineConfig.negativePrompt = request.negativePrompt
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
pipelineConfig.imageCount = Int(request.batchSize)
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.imageCount {
for _ in 0 ..< request.batchCount {
var seed = baseSeed
if seed == 0 {
seed = UInt32.random(in: 0 ..< UInt32.max)
}
pipelineConfig.seed = seed
let images = try pipeline.generateImages(configuration: pipelineConfig)
for cgImage in images {
guard let cgImage else { continue }
var image = SdImage()
image.content = try cgImage.toPngData()
response.images.append(image)
try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
}
response.seeds.append(seed)
}
return response
}