Initial Commit

This commit is contained in:
2023-04-22 14:52:27 -07:00
commit 2759c8d7fb
16 changed files with 1806 additions and 0 deletions

View File

@ -0,0 +1,7 @@
import Foundation
enum SdServerError: Error {
case modelNotLoaded
case imageEncode
case modelNotFound
}

View File

@ -0,0 +1,22 @@
import Foundation
import CoreImage
import UniformTypeIdentifiers
extension CGImage {
func toPngData() throws -> Data {
guard let data = CFDataCreateMutable(nil, 0) else {
throw SdServerError.imageEncode
}
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
throw SdServerError.imageEncode
}
CGImageDestinationAddImage(destination, self, nil)
if CGImageDestinationFinalize(destination) {
return data as Data
} else {
throw SdServerError.imageEncode
}
}
}

View File

@ -0,0 +1,18 @@
import Foundation
import GRPC
import StableDiffusionProtos
class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
private let modelManager: ModelManager
init(modelManager: ModelManager) {
self.modelManager = modelManager
}
func generateImage(request: SdGenerateImagesRequest, context: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdServerError.modelNotFound
}
return try await state.generate(request)
}
}

View File

@ -0,0 +1,66 @@
import Foundation
import StableDiffusion
import StableDiffusionProtos
actor ModelManager {
private var modelInfos: [String : SdModelInfo] = [:]
private var modelUrls: [String : URL] = [:]
private var modelStates: [String : ModelState] = [:]
private let modelBaseURL: URL
public init(modelBaseURL: URL) {
self.modelBaseURL = modelBaseURL
}
func reloadModels() throws {
modelInfos.removeAll()
modelStates.removeAll()
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
for subdirectoryURL in contents {
let values = try subdirectoryURL.resourceValues(forKeys: [.isDirectoryKey])
if values.isDirectory ?? false {
try addModel(url: subdirectoryURL)
}
}
}
func listModels() -> [SdModelInfo] {
return Array(modelInfos.values)
}
func getModelState(name: String) -> ModelState? {
return modelStates[name]
}
private func addModel(url: URL) throws {
var info = SdModelInfo()
info.name = url.lastPathComponent
let attention = getModelAttention(url)
info.attention = attention ?? "unknown"
modelInfos[info.name] = info
modelUrls[info.name] = url
modelStates[info.name] = try ModelState(url: url)
}
private func getModelAttention(_ url: URL) -> String? {
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
struct ModelMetadata: Decodable {
let mlProgramOperationTypeHistogram: [String: Int]
}
do {
let jsonData = try Data(contentsOf: unetMetadataURL)
let metadatas = try JSONDecoder().decode([ModelMetadata].self, from: jsonData)
guard metadatas.count == 1 else {
return nil
}
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original"
} catch {
return nil
}
}
}

View File

@ -0,0 +1,31 @@
import Foundation
import GRPC
import StableDiffusionProtos
class ModelServiceProvider: SdModelServiceAsyncProvider {
private let modelManager: ModelManager
init(modelManager: ModelManager) {
self.modelManager = modelManager
}
func listModels(request: SdListModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse {
let models = await modelManager.listModels()
var response = SdListModelsResponse()
response.models.append(contentsOf: models)
return response
}
func reloadModels(request: SdReloadModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse {
try await modelManager.reloadModels()
return SdReloadModelsResponse()
}
func loadModel(request: SdLoadModelRequest, context: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdServerError.modelNotFound
}
try await state.load()
return SdLoadModelResponse()
}
}

View File

@ -0,0 +1,52 @@
import Foundation
import StableDiffusionProtos
import StableDiffusion
import CoreML
actor ModelState {
private let url: URL
private var pipeline: StableDiffusionPipeline? = nil
private var tokenizer: BPETokenizer? = nil
init(url: URL) throws {
self.url = url
}
func load() throws {
let config = MLModelConfiguration()
config.computeUnits = .all
pipeline = try StableDiffusionPipeline(
resourcesAt: url,
controlNet: [],
configuration: config,
disableSafety: true,
reduceMemory: false
)
let mergesUrl = url.appending(component: "merges.txt")
let vocabUrl = url.appending(component: "vocab.json")
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
}
func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
guard let pipeline else {
throw SdServerError.modelNotLoaded
}
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
pipelineConfig.negativePrompt = request.negativePrompt
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.imageCount {
let images = try pipeline.generateImages(configuration: pipelineConfig)
for cgImage in images {
guard let cgImage else { continue }
var image = SdImage()
image.content = try cgImage.toPngData()
response.images.append(image)
}
}
return response
}
}

View File

@ -0,0 +1,40 @@
import Foundation
import ArgumentParser
import GRPC
import NIO
import System
struct ServerCommand: ParsableCommand {
@Option(name: .shortAndLong, help: "Path to models directory")
var modelsDirectoryPath: String = "models"
mutating func run() throws {
let modelsDirectoryURL = URL(filePath: modelsDirectoryPath)
let modelManager = ModelManager(modelBaseURL: modelsDirectoryURL)
let semaphore = DispatchSemaphore(value: 0)
Task {
print("Loading initial models...")
do {
try await modelManager.reloadModels()
} catch {
ServerCommand.exit(withError: error)
}
print("Loaded initial models.")
semaphore.signal()
}
semaphore.wait()
let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
_ = Server.insecure(group: group)
.withServiceProviders([
ModelServiceProvider(modelManager: modelManager),
ImageGenerationServiceProvider(modelManager: modelManager)
])
.bind(host: "0.0.0.0", port: 4546)
dispatchMain()
}
}
ServerCommand.main()