mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 05:30:54 +00:00
Formatting, linting, and hopefully a CI build.
This commit is contained in:
31
.github/workflows/macos.yml
vendored
Normal file
31
.github/workflows/macos.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
name: macOS
|
||||||
|
on: [push]
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: macos-12
|
||||||
|
steps:
|
||||||
|
- name: Checkout Repository
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
- name: Build Executable
|
||||||
|
run: swift build -c release --arch arm64 --arch x86_64
|
||||||
|
- name: Copy Executable
|
||||||
|
run: cp .build/apple/Products/Release/StableDiffusionServer StableDiffusionServer
|
||||||
|
- name: Archive Executable
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
with:
|
||||||
|
name: StableDiffusionServer
|
||||||
|
path: StableDiffusionServer
|
||||||
|
format:
|
||||||
|
runs-on: macos-12
|
||||||
|
steps:
|
||||||
|
- name: Checkout Repository
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
- name: Swift Format
|
||||||
|
run: swiftformat --lint Package.swift Sources
|
||||||
|
lint:
|
||||||
|
runs-on: macos-12
|
||||||
|
steps:
|
||||||
|
- name: Checkout Repository
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
- name: Swift Lint
|
||||||
|
run: swiftlint Package.swift Sources
|
1
.swift-version
Normal file
1
.swift-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
5.6
|
4
.swiftformat
Normal file
4
.swiftformat
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
--indent 4
|
||||||
|
--disable trailingCommas
|
||||||
|
--exclude "Sources/StableDiffusionProtos/*.pb.swift"
|
||||||
|
--exclude "Sources/StableDiffusionProtos/*.grpc.swift"
|
4
.swiftlint.yml
Normal file
4
.swiftlint.yml
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
line_length: 180
|
||||||
|
excluded:
|
||||||
|
- Sources/StableDiffusionProtos/*.pb.swift
|
||||||
|
- Sources/StableDiffusionProtos/*.grpc.swift
|
@ -16,17 +16,22 @@ let package = Package(
|
|||||||
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0")
|
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0")
|
||||||
],
|
],
|
||||||
targets: [
|
targets: [
|
||||||
|
.target(name: "StableDiffusionProtos", dependencies: [
|
||||||
|
.product(name: "SwiftProtobuf", package: "swift-protobuf"),
|
||||||
|
.product(name: "GRPC", package: "grpc-swift")
|
||||||
|
]),
|
||||||
|
.target(name: "StableDiffusionCore", dependencies: [
|
||||||
|
.product(name: "StableDiffusion", package: "ml-stable-diffusion"),
|
||||||
|
.target(name: "StableDiffusionProtos")
|
||||||
|
]),
|
||||||
.executableTarget(name: "StableDiffusionServer", dependencies: [
|
.executableTarget(name: "StableDiffusionServer", dependencies: [
|
||||||
.product(name: "StableDiffusion", package: "ml-stable-diffusion"),
|
.product(name: "StableDiffusion", package: "ml-stable-diffusion"),
|
||||||
.product(name: "SwiftProtobuf", package: "swift-protobuf"),
|
.product(name: "SwiftProtobuf", package: "swift-protobuf"),
|
||||||
.product(name: "GRPC", package: "grpc-swift"),
|
.product(name: "GRPC", package: "grpc-swift"),
|
||||||
.target(name: "StableDiffusionProtos"),
|
.target(name: "StableDiffusionProtos"),
|
||||||
|
.target(name: "StableDiffusionCore"),
|
||||||
.product(name: "ArgumentParser", package: "swift-argument-parser")
|
.product(name: "ArgumentParser", package: "swift-argument-parser")
|
||||||
]),
|
]),
|
||||||
.target(name: "StableDiffusionProtos", dependencies: [
|
|
||||||
.product(name: "SwiftProtobuf", package: "swift-protobuf"),
|
|
||||||
.product(name: "GRPC", package: "grpc-swift")
|
|
||||||
]),
|
|
||||||
.executableTarget(name: "TestStableDiffusionClient", dependencies: [
|
.executableTarget(name: "TestStableDiffusionClient", dependencies: [
|
||||||
.target(name: "StableDiffusionProtos"),
|
.target(name: "StableDiffusionProtos"),
|
||||||
.product(name: "GRPC", package: "grpc-swift")
|
.product(name: "GRPC", package: "grpc-swift")
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
|
|
||||||
enum SdServerError: Error {
|
public enum SdCoreError: Error {
|
||||||
case modelNotLoaded
|
case modelNotLoaded
|
||||||
case imageEncode
|
case imageEncode
|
||||||
case modelNotFound
|
case modelNotFound
|
@ -1,22 +1,22 @@
|
|||||||
import Foundation
|
|
||||||
import CoreImage
|
import CoreImage
|
||||||
|
import Foundation
|
||||||
import UniformTypeIdentifiers
|
import UniformTypeIdentifiers
|
||||||
|
|
||||||
extension CGImage {
|
extension CGImage {
|
||||||
func toPngData() throws -> Data {
|
func toPngData() throws -> Data {
|
||||||
guard let data = CFDataCreateMutable(nil, 0) else {
|
guard let data = CFDataCreateMutable(nil, 0) else {
|
||||||
throw SdServerError.imageEncode
|
throw SdCoreError.imageEncode
|
||||||
}
|
}
|
||||||
|
|
||||||
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
|
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
|
||||||
throw SdServerError.imageEncode
|
throw SdCoreError.imageEncode
|
||||||
}
|
}
|
||||||
|
|
||||||
CGImageDestinationAddImage(destination, self, nil)
|
CGImageDestinationAddImage(destination, self, nil)
|
||||||
if CGImageDestinationFinalize(destination) {
|
if CGImageDestinationFinalize(destination) {
|
||||||
return data as Data
|
return data as Data
|
||||||
} else {
|
} else {
|
||||||
throw SdServerError.imageEncode
|
throw SdCoreError.imageEncode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -2,18 +2,18 @@ import Foundation
|
|||||||
import StableDiffusion
|
import StableDiffusion
|
||||||
import StableDiffusionProtos
|
import StableDiffusionProtos
|
||||||
|
|
||||||
actor ModelManager {
|
public actor ModelManager {
|
||||||
private var modelInfos: [String : SdModelInfo] = [:]
|
private var modelInfos: [String: SdModelInfo] = [:]
|
||||||
private var modelUrls: [String : URL] = [:]
|
private var modelUrls: [String: URL] = [:]
|
||||||
private var modelStates: [String : ModelState] = [:]
|
private var modelStates: [String: ModelState] = [:]
|
||||||
|
|
||||||
private let modelBaseURL: URL
|
private let modelBaseURL: URL
|
||||||
|
|
||||||
public init(modelBaseURL: URL) {
|
public init(modelBaseURL: URL) {
|
||||||
self.modelBaseURL = modelBaseURL
|
self.modelBaseURL = modelBaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadModels() throws {
|
public func reloadModels() throws {
|
||||||
modelInfos.removeAll()
|
modelInfos.removeAll()
|
||||||
modelStates.removeAll()
|
modelStates.removeAll()
|
||||||
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
|
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
|
||||||
@ -24,15 +24,15 @@ actor ModelManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func listModels() -> [SdModelInfo] {
|
public func listModels() -> [SdModelInfo] {
|
||||||
return Array(modelInfos.values)
|
Array(modelInfos.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getModelState(name: String) -> ModelState? {
|
public func getModelState(name: String) -> ModelState? {
|
||||||
return modelStates[name]
|
modelStates[name]
|
||||||
}
|
}
|
||||||
|
|
||||||
private func addModel(url: URL) throws {
|
private func addModel(url: URL) throws {
|
||||||
var info = SdModelInfo()
|
var info = SdModelInfo()
|
||||||
info.name = url.lastPathComponent
|
info.name = url.lastPathComponent
|
||||||
@ -42,10 +42,10 @@ actor ModelManager {
|
|||||||
modelUrls[info.name] = url
|
modelUrls[info.name] = url
|
||||||
modelStates[info.name] = try ModelState(url: url)
|
modelStates[info.name] = try ModelState(url: url)
|
||||||
}
|
}
|
||||||
|
|
||||||
private func getModelAttention(_ url: URL) -> String? {
|
private func getModelAttention(_ url: URL) -> String? {
|
||||||
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
|
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
|
||||||
|
|
||||||
struct ModelMetadata: Decodable {
|
struct ModelMetadata: Decodable {
|
||||||
let mlProgramOperationTypeHistogram: [String: Int]
|
let mlProgramOperationTypeHistogram: [String: Int]
|
||||||
}
|
}
|
@ -1,18 +1,18 @@
|
|||||||
import Foundation
|
|
||||||
import StableDiffusionProtos
|
|
||||||
import StableDiffusion
|
|
||||||
import CoreML
|
import CoreML
|
||||||
|
import Foundation
|
||||||
|
import StableDiffusion
|
||||||
|
import StableDiffusionProtos
|
||||||
|
|
||||||
actor ModelState {
|
public actor ModelState {
|
||||||
private let url: URL
|
private let url: URL
|
||||||
private var pipeline: StableDiffusionPipeline? = nil
|
private var pipeline: StableDiffusionPipeline?
|
||||||
private var tokenizer: BPETokenizer? = nil
|
private var tokenizer: BPETokenizer?
|
||||||
|
|
||||||
init(url: URL) throws {
|
public init(url: URL) throws {
|
||||||
self.url = url
|
self.url = url
|
||||||
}
|
}
|
||||||
|
|
||||||
func load() throws {
|
public func load() throws {
|
||||||
let config = MLModelConfiguration()
|
let config = MLModelConfiguration()
|
||||||
config.computeUnits = .all
|
config.computeUnits = .all
|
||||||
pipeline = try StableDiffusionPipeline(
|
pipeline = try StableDiffusionPipeline(
|
||||||
@ -26,20 +26,20 @@ actor ModelState {
|
|||||||
let vocabUrl = url.appending(component: "vocab.json")
|
let vocabUrl = url.appending(component: "vocab.json")
|
||||||
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
||||||
guard let pipeline else {
|
guard let pipeline else {
|
||||||
throw SdServerError.modelNotLoaded
|
throw SdCoreError.modelNotLoaded
|
||||||
}
|
}
|
||||||
|
|
||||||
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
||||||
pipelineConfig.negativePrompt = request.negativePrompt
|
pipelineConfig.negativePrompt = request.negativePrompt
|
||||||
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
|
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||||
|
|
||||||
var response = SdGenerateImagesResponse()
|
var response = SdGenerateImagesResponse()
|
||||||
for _ in 0 ..< request.imageCount {
|
for _ in 0 ..< request.imageCount {
|
||||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||||
|
|
||||||
for cgImage in images {
|
for cgImage in images {
|
||||||
guard let cgImage else { continue }
|
guard let cgImage else { continue }
|
||||||
var image = SdImage()
|
var image = SdImage()
|
@ -1,17 +1,18 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import GRPC
|
import GRPC
|
||||||
|
import StableDiffusionCore
|
||||||
import StableDiffusionProtos
|
import StableDiffusionProtos
|
||||||
|
|
||||||
class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
|
class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
|
||||||
private let modelManager: ModelManager
|
private let modelManager: ModelManager
|
||||||
|
|
||||||
init(modelManager: ModelManager) {
|
init(modelManager: ModelManager) {
|
||||||
self.modelManager = modelManager
|
self.modelManager = modelManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateImage(request: SdGenerateImagesRequest, context: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
|
func generateImage(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
|
||||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||||
throw SdServerError.modelNotFound
|
throw SdCoreError.modelNotFound
|
||||||
}
|
}
|
||||||
return try await state.generate(request)
|
return try await state.generate(request)
|
||||||
}
|
}
|
||||||
|
@ -1,29 +1,30 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import GRPC
|
import GRPC
|
||||||
|
import StableDiffusionCore
|
||||||
import StableDiffusionProtos
|
import StableDiffusionProtos
|
||||||
|
|
||||||
class ModelServiceProvider: SdModelServiceAsyncProvider {
|
class ModelServiceProvider: SdModelServiceAsyncProvider {
|
||||||
private let modelManager: ModelManager
|
private let modelManager: ModelManager
|
||||||
|
|
||||||
init(modelManager: ModelManager) {
|
init(modelManager: ModelManager) {
|
||||||
self.modelManager = modelManager
|
self.modelManager = modelManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func listModels(request: SdListModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse {
|
func listModels(request _: SdListModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse {
|
||||||
let models = await modelManager.listModels()
|
let models = await modelManager.listModels()
|
||||||
var response = SdListModelsResponse()
|
var response = SdListModelsResponse()
|
||||||
response.models.append(contentsOf: models)
|
response.models.append(contentsOf: models)
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadModels(request: SdReloadModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse {
|
func reloadModels(request _: SdReloadModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse {
|
||||||
try await modelManager.reloadModels()
|
try await modelManager.reloadModels()
|
||||||
return SdReloadModelsResponse()
|
return SdReloadModelsResponse()
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadModel(request: SdLoadModelRequest, context: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
|
func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
|
||||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||||
throw SdServerError.modelNotFound
|
throw SdCoreError.modelNotFound
|
||||||
}
|
}
|
||||||
try await state.load()
|
try await state.load()
|
||||||
return SdLoadModelResponse()
|
return SdLoadModelResponse()
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
import Foundation
|
|
||||||
import ArgumentParser
|
import ArgumentParser
|
||||||
|
import Foundation
|
||||||
import GRPC
|
import GRPC
|
||||||
import NIO
|
import NIO
|
||||||
|
import StableDiffusionCore
|
||||||
import System
|
import System
|
||||||
|
|
||||||
struct ServerCommand: ParsableCommand {
|
struct ServerCommand: ParsableCommand {
|
||||||
@Option(name: .shortAndLong, help: "Path to models directory")
|
@Option(name: .shortAndLong, help: "Path to models directory")
|
||||||
var modelsDirectoryPath: String = "models"
|
var modelsDirectoryPath: String = "models"
|
||||||
|
|
||||||
mutating func run() throws {
|
mutating func run() throws {
|
||||||
let modelsDirectoryURL = URL(filePath: modelsDirectoryPath)
|
let modelsDirectoryURL = URL(filePath: modelsDirectoryPath)
|
||||||
let modelManager = ModelManager(modelBaseURL: modelsDirectoryURL)
|
let modelManager = ModelManager(modelBaseURL: modelsDirectoryURL)
|
||||||
|
|
||||||
let semaphore = DispatchSemaphore(value: 0)
|
let semaphore = DispatchSemaphore(value: 0)
|
||||||
Task {
|
Task {
|
||||||
print("Loading initial models...")
|
print("Loading initial models...")
|
||||||
@ -32,9 +33,9 @@ struct ServerCommand: ParsableCommand {
|
|||||||
ImageGenerationServiceProvider(modelManager: modelManager)
|
ImageGenerationServiceProvider(modelManager: modelManager)
|
||||||
])
|
])
|
||||||
.bind(host: "0.0.0.0", port: 4546)
|
.bind(host: "0.0.0.0", port: 4546)
|
||||||
|
|
||||||
dispatchMain()
|
dispatchMain()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ServerCommand.main()
|
ServerCommand.main()
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import StableDiffusionProtos
|
|
||||||
import NIO
|
|
||||||
import System
|
|
||||||
import GRPC
|
import GRPC
|
||||||
|
import NIO
|
||||||
|
import StableDiffusionProtos
|
||||||
|
import System
|
||||||
|
|
||||||
let group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
|
let group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
|
||||||
defer {
|
defer {
|
||||||
try? group.syncShutdownGracefully()
|
try? group.syncShutdownGracefully()
|
||||||
}
|
}
|
||||||
|
|
||||||
let channel = try GRPCChannelPool.with(
|
let channel = try GRPCChannelPool.with(
|
||||||
target: .host("localhost", port: 4546),
|
target: .host("localhost", port: 4546),
|
||||||
transportSecurity: .plaintext,
|
transportSecurity: .plaintext,
|
||||||
eventLoopGroup: group
|
eventLoopGroup: group
|
||||||
)
|
)
|
||||||
|
|
||||||
let modelService = SdModelServiceAsyncClient(channel: channel)
|
let modelService = SdModelServiceAsyncClient(channel: channel)
|
||||||
@ -27,14 +27,14 @@ Task { @MainActor in
|
|||||||
request.modelName = modelInfo.name
|
request.modelName = modelInfo.name
|
||||||
})
|
})
|
||||||
print("Loaded model.")
|
print("Loaded model.")
|
||||||
|
|
||||||
print("Generating image...")
|
print("Generating image...")
|
||||||
let request = SdGenerateImagesRequest.with {
|
let request = SdGenerateImagesRequest.with {
|
||||||
$0.modelName = modelInfo.name
|
$0.modelName = modelInfo.name
|
||||||
$0.prompt = "cat"
|
$0.prompt = "cat"
|
||||||
$0.imageCount = 1
|
$0.imageCount = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = try await imageGeneratorService.generateImage(request)
|
let response = try await imageGeneratorService.generateImage(request)
|
||||||
print("Generated image.")
|
print("Generated image.")
|
||||||
print(response)
|
print(response)
|
||||||
@ -43,4 +43,5 @@ Task { @MainActor in
|
|||||||
exit(1)
|
exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatchMain()
|
dispatchMain()
|
||||||
|
Reference in New Issue
Block a user