mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-04 05:51:32 +00:00
Initial Commit
This commit is contained in:
7
Sources/StableDiffusionServer/Errors.swift
Normal file
7
Sources/StableDiffusionServer/Errors.swift
Normal file
@ -0,0 +1,7 @@
|
||||
import Foundation
|
||||
|
||||
enum SdServerError: Error {
|
||||
case modelNotLoaded
|
||||
case imageEncode
|
||||
case modelNotFound
|
||||
}
|
22
Sources/StableDiffusionServer/ImageExtensions.swift
Normal file
22
Sources/StableDiffusionServer/ImageExtensions.swift
Normal file
@ -0,0 +1,22 @@
|
||||
import Foundation
|
||||
import CoreImage
|
||||
import UniformTypeIdentifiers
|
||||
|
||||
extension CGImage {
|
||||
func toPngData() throws -> Data {
|
||||
guard let data = CFDataCreateMutable(nil, 0) else {
|
||||
throw SdServerError.imageEncode
|
||||
}
|
||||
|
||||
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
|
||||
throw SdServerError.imageEncode
|
||||
}
|
||||
|
||||
CGImageDestinationAddImage(destination, self, nil)
|
||||
if CGImageDestinationFinalize(destination) {
|
||||
return data as Data
|
||||
} else {
|
||||
throw SdServerError.imageEncode
|
||||
}
|
||||
}
|
||||
}
|
18
Sources/StableDiffusionServer/ImageGenerationService.swift
Normal file
18
Sources/StableDiffusionServer/ImageGenerationService.swift
Normal file
@ -0,0 +1,18 @@
|
||||
import Foundation
|
||||
import GRPC
|
||||
import StableDiffusionProtos
|
||||
|
||||
class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
|
||||
private let modelManager: ModelManager
|
||||
|
||||
init(modelManager: ModelManager) {
|
||||
self.modelManager = modelManager
|
||||
}
|
||||
|
||||
func generateImage(request: SdGenerateImagesRequest, context: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdServerError.modelNotFound
|
||||
}
|
||||
return try await state.generate(request)
|
||||
}
|
||||
}
|
66
Sources/StableDiffusionServer/ModelManager.swift
Normal file
66
Sources/StableDiffusionServer/ModelManager.swift
Normal file
@ -0,0 +1,66 @@
|
||||
import Foundation
|
||||
import StableDiffusion
|
||||
import StableDiffusionProtos
|
||||
|
||||
actor ModelManager {
|
||||
private var modelInfos: [String : SdModelInfo] = [:]
|
||||
private var modelUrls: [String : URL] = [:]
|
||||
private var modelStates: [String : ModelState] = [:]
|
||||
|
||||
private let modelBaseURL: URL
|
||||
|
||||
public init(modelBaseURL: URL) {
|
||||
self.modelBaseURL = modelBaseURL
|
||||
}
|
||||
|
||||
func reloadModels() throws {
|
||||
modelInfos.removeAll()
|
||||
modelStates.removeAll()
|
||||
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
|
||||
for subdirectoryURL in contents {
|
||||
let values = try subdirectoryURL.resourceValues(forKeys: [.isDirectoryKey])
|
||||
if values.isDirectory ?? false {
|
||||
try addModel(url: subdirectoryURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func listModels() -> [SdModelInfo] {
|
||||
return Array(modelInfos.values)
|
||||
}
|
||||
|
||||
func getModelState(name: String) -> ModelState? {
|
||||
return modelStates[name]
|
||||
}
|
||||
|
||||
private func addModel(url: URL) throws {
|
||||
var info = SdModelInfo()
|
||||
info.name = url.lastPathComponent
|
||||
let attention = getModelAttention(url)
|
||||
info.attention = attention ?? "unknown"
|
||||
modelInfos[info.name] = info
|
||||
modelUrls[info.name] = url
|
||||
modelStates[info.name] = try ModelState(url: url)
|
||||
}
|
||||
|
||||
private func getModelAttention(_ url: URL) -> String? {
|
||||
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
|
||||
|
||||
struct ModelMetadata: Decodable {
|
||||
let mlProgramOperationTypeHistogram: [String: Int]
|
||||
}
|
||||
|
||||
do {
|
||||
let jsonData = try Data(contentsOf: unetMetadataURL)
|
||||
let metadatas = try JSONDecoder().decode([ModelMetadata].self, from: jsonData)
|
||||
|
||||
guard metadatas.count == 1 else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original"
|
||||
} catch {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
31
Sources/StableDiffusionServer/ModelService.swift
Normal file
31
Sources/StableDiffusionServer/ModelService.swift
Normal file
@ -0,0 +1,31 @@
|
||||
import Foundation
|
||||
import GRPC
|
||||
import StableDiffusionProtos
|
||||
|
||||
class ModelServiceProvider: SdModelServiceAsyncProvider {
|
||||
private let modelManager: ModelManager
|
||||
|
||||
init(modelManager: ModelManager) {
|
||||
self.modelManager = modelManager
|
||||
}
|
||||
|
||||
func listModels(request: SdListModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse {
|
||||
let models = await modelManager.listModels()
|
||||
var response = SdListModelsResponse()
|
||||
response.models.append(contentsOf: models)
|
||||
return response
|
||||
}
|
||||
|
||||
func reloadModels(request: SdReloadModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse {
|
||||
try await modelManager.reloadModels()
|
||||
return SdReloadModelsResponse()
|
||||
}
|
||||
|
||||
func loadModel(request: SdLoadModelRequest, context: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdServerError.modelNotFound
|
||||
}
|
||||
try await state.load()
|
||||
return SdLoadModelResponse()
|
||||
}
|
||||
}
|
52
Sources/StableDiffusionServer/ModelState.swift
Normal file
52
Sources/StableDiffusionServer/ModelState.swift
Normal file
@ -0,0 +1,52 @@
|
||||
import Foundation
|
||||
import StableDiffusionProtos
|
||||
import StableDiffusion
|
||||
import CoreML
|
||||
|
||||
actor ModelState {
|
||||
private let url: URL
|
||||
private var pipeline: StableDiffusionPipeline? = nil
|
||||
private var tokenizer: BPETokenizer? = nil
|
||||
|
||||
init(url: URL) throws {
|
||||
self.url = url
|
||||
}
|
||||
|
||||
func load() throws {
|
||||
let config = MLModelConfiguration()
|
||||
config.computeUnits = .all
|
||||
pipeline = try StableDiffusionPipeline(
|
||||
resourcesAt: url,
|
||||
controlNet: [],
|
||||
configuration: config,
|
||||
disableSafety: true,
|
||||
reduceMemory: false
|
||||
)
|
||||
let mergesUrl = url.appending(component: "merges.txt")
|
||||
let vocabUrl = url.appending(component: "vocab.json")
|
||||
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
||||
}
|
||||
|
||||
func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
||||
guard let pipeline else {
|
||||
throw SdServerError.modelNotLoaded
|
||||
}
|
||||
|
||||
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
||||
pipelineConfig.negativePrompt = request.negativePrompt
|
||||
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||
|
||||
var response = SdGenerateImagesResponse()
|
||||
for _ in 0 ..< request.imageCount {
|
||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||
|
||||
for cgImage in images {
|
||||
guard let cgImage else { continue }
|
||||
var image = SdImage()
|
||||
image.content = try cgImage.toPngData()
|
||||
response.images.append(image)
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
}
|
40
Sources/StableDiffusionServer/main.swift
Normal file
40
Sources/StableDiffusionServer/main.swift
Normal file
@ -0,0 +1,40 @@
|
||||
import Foundation
|
||||
import ArgumentParser
|
||||
import GRPC
|
||||
import NIO
|
||||
import System
|
||||
|
||||
struct ServerCommand: ParsableCommand {
|
||||
@Option(name: .shortAndLong, help: "Path to models directory")
|
||||
var modelsDirectoryPath: String = "models"
|
||||
|
||||
mutating func run() throws {
|
||||
let modelsDirectoryURL = URL(filePath: modelsDirectoryPath)
|
||||
let modelManager = ModelManager(modelBaseURL: modelsDirectoryURL)
|
||||
|
||||
let semaphore = DispatchSemaphore(value: 0)
|
||||
Task {
|
||||
print("Loading initial models...")
|
||||
do {
|
||||
try await modelManager.reloadModels()
|
||||
} catch {
|
||||
ServerCommand.exit(withError: error)
|
||||
}
|
||||
print("Loaded initial models.")
|
||||
semaphore.signal()
|
||||
}
|
||||
semaphore.wait()
|
||||
|
||||
let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
|
||||
_ = Server.insecure(group: group)
|
||||
.withServiceProviders([
|
||||
ModelServiceProvider(modelManager: modelManager),
|
||||
ImageGenerationServiceProvider(modelManager: modelManager)
|
||||
])
|
||||
.bind(host: "0.0.0.0", port: 4546)
|
||||
|
||||
dispatchMain()
|
||||
}
|
||||
|
||||
}
|
||||
ServerCommand.main()
|
Reference in New Issue
Block a user