commit 2759c8d7fbb282438ec8505264971fc01fdde70b Author: Alex Zenla Date: Sat Apr 22 14:52:27 2023 -0700 Initial Commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a43e767 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +.DS_Store +/.build +/Packages +/*.xcodeproj +xcuserdata/ +DerivedData/ +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.swiftpm/xcode/xcshareddata/xcschemes/*.xcscheme +/models +/.vscode +/Package.resolved diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto new file mode 100644 index 0000000..c443f81 --- /dev/null +++ b/Common/StableDiffusion.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; +package gay.pizza.stable.diffusion; + +option swift_prefix = "Sd"; + +message ModelInfo { + string name = 1; + string attention = 2; +} + +message Image { + bytes content = 1; +} + +message ListModelsRequest {} +message ListModelsResponse { + repeated ModelInfo models = 1; +} + +message ReloadModelsRequest {} +message ReloadModelsResponse {} + +enum Scheduler { + pndm = 0; + dpmSolverPlusPlus = 1; +} + +enum ComputeUnits { + cpu = 0; + cpu_and_gpu = 1; + all = 2; + cpu_and_neural_engine = 3; +} + +message LoadModelRequest { + string model_name = 1; + ComputeUnits compute_units = 2; + bool reduce_memory = 3; +} + +message LoadModelResponse {} + +service ModelService { + rpc ListModels(ListModelsRequest) returns (ListModelsResponse); + rpc ReloadModels(ReloadModelsRequest) returns (ReloadModelsResponse); + rpc LoadModel(LoadModelRequest) returns (LoadModelResponse); +} + +message GenerateImagesRequest { + string model_name = 1; + uint32 image_count = 2; + string prompt = 3; + string negative_prompt = 4; +} + +message GenerateImagesResponse { + repeated Image images = 1; +} + +service ImageGenerationService { + rpc GenerateImage(GenerateImagesRequest) returns (GenerateImagesResponse); +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e906025 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Gay Pizza Specifications + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Package.swift b/Package.swift new file mode 100644 index 0000000..d5b8c44 --- /dev/null +++ b/Package.swift @@ -0,0 +1,35 @@ +// swift-tools-version: 5.7 +import PackageDescription + +let package = Package( + name: "stable-diffusion-rpc", + platforms: [.macOS("13.1"), .iOS("16.2")], + products: [ + .executable(name: "StableDiffusionServer", targets: ["StableDiffusionServer"]), + .library(name: "StableDiffusionProtos", targets: ["StableDiffusionProtos"]), + .executable(name: "TestStableDiffusionClient", targets: ["TestStableDiffusionClient"]) + ], + dependencies: [ + .package(url: "https://github.com/apple/ml-stable-diffusion", revision: "5d2744e38297b01662b8bdfb41e899ac98036d8b"), + .package(url: "https://github.com/apple/swift-protobuf", from: "1.6.0"), + .package(url: "https://github.com/grpc/grpc-swift.git", from: "1.15.0"), + .package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0") + ], + targets: [ + .executableTarget(name: "StableDiffusionServer", dependencies: [ + .product(name: "StableDiffusion", package: "ml-stable-diffusion"), + .product(name: "SwiftProtobuf", package: "swift-protobuf"), + .product(name: "GRPC", package: "grpc-swift"), + .target(name: "StableDiffusionProtos"), + .product(name: "ArgumentParser", package: "swift-argument-parser") + ]), + .target(name: "StableDiffusionProtos", dependencies: [ + .product(name: "SwiftProtobuf", package: "swift-protobuf"), + .product(name: "GRPC", package: "grpc-swift") + ]), + .executableTarget(name: "TestStableDiffusionClient", dependencies: [ + .target(name: "StableDiffusionProtos"), + .product(name: "GRPC", package: "grpc-swift") + ]) + ] +) diff --git a/README.md b/README.md new file mode 100644 index 0000000..48488f5 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# Stable Diffusion RPC + +A gRPC server for a Stable Diffusion worker on Apple Platforms. diff --git a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift new file mode 100644 index 0000000..10880c5 --- /dev/null +++ b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift @@ -0,0 +1,814 @@ +// +// DO NOT EDIT. +// +// Generated by the protocol buffer compiler. +// Source: StableDiffusion.proto +// + +// +// Copyright 2018, gRPC Authors All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +import GRPC +import NIO +import NIOConcurrencyHelpers +import SwiftProtobuf + + +/// Usage: instantiate `SdModelServiceClient`, then call methods of this protocol to make API calls. +public protocol SdModelServiceClientProtocol: GRPCClient { + var serviceName: String { get } + var interceptors: SdModelServiceClientInterceptorFactoryProtocol? { get } + + func listModels( + _ request: SdListModelsRequest, + callOptions: CallOptions? + ) -> UnaryCall + + func reloadModels( + _ request: SdReloadModelsRequest, + callOptions: CallOptions? + ) -> UnaryCall + + func loadModel( + _ request: SdLoadModelRequest, + callOptions: CallOptions? + ) -> UnaryCall +} + +extension SdModelServiceClientProtocol { + public var serviceName: String { + return "gay.pizza.stable.diffusion.ModelService" + } + + /// Unary call to ListModels + /// + /// - Parameters: + /// - request: Request to send to ListModels. + /// - callOptions: Call options. + /// - Returns: A `UnaryCall` with futures for the metadata, status and response. + public func listModels( + _ request: SdListModelsRequest, + callOptions: CallOptions? = nil + ) -> UnaryCall { + return self.makeUnaryCall( + path: SdModelServiceClientMetadata.Methods.listModels.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeListModelsInterceptors() ?? [] + ) + } + + /// Unary call to ReloadModels + /// + /// - Parameters: + /// - request: Request to send to ReloadModels. + /// - callOptions: Call options. + /// - Returns: A `UnaryCall` with futures for the metadata, status and response. + public func reloadModels( + _ request: SdReloadModelsRequest, + callOptions: CallOptions? = nil + ) -> UnaryCall { + return self.makeUnaryCall( + path: SdModelServiceClientMetadata.Methods.reloadModels.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [] + ) + } + + /// Unary call to LoadModel + /// + /// - Parameters: + /// - request: Request to send to LoadModel. + /// - callOptions: Call options. + /// - Returns: A `UnaryCall` with futures for the metadata, status and response. + public func loadModel( + _ request: SdLoadModelRequest, + callOptions: CallOptions? = nil + ) -> UnaryCall { + return self.makeUnaryCall( + path: SdModelServiceClientMetadata.Methods.loadModel.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeLoadModelInterceptors() ?? [] + ) + } +} + +#if compiler(>=5.6) +@available(*, deprecated) +extension SdModelServiceClient: @unchecked Sendable {} +#endif // compiler(>=5.6) + +@available(*, deprecated, renamed: "SdModelServiceNIOClient") +public final class SdModelServiceClient: SdModelServiceClientProtocol { + private let lock = Lock() + private var _defaultCallOptions: CallOptions + private var _interceptors: SdModelServiceClientInterceptorFactoryProtocol? + public let channel: GRPCChannel + public var defaultCallOptions: CallOptions { + get { self.lock.withLock { return self._defaultCallOptions } } + set { self.lock.withLockVoid { self._defaultCallOptions = newValue } } + } + public var interceptors: SdModelServiceClientInterceptorFactoryProtocol? { + get { self.lock.withLock { return self._interceptors } } + set { self.lock.withLockVoid { self._interceptors = newValue } } + } + + /// Creates a client for the gay.pizza.stable.diffusion.ModelService service. + /// + /// - Parameters: + /// - channel: `GRPCChannel` to the service host. + /// - defaultCallOptions: Options to use for each service call if the user doesn't provide them. + /// - interceptors: A factory providing interceptors for each RPC. + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdModelServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self._defaultCallOptions = defaultCallOptions + self._interceptors = interceptors + } +} + +public struct SdModelServiceNIOClient: SdModelServiceClientProtocol { + public var channel: GRPCChannel + public var defaultCallOptions: CallOptions + public var interceptors: SdModelServiceClientInterceptorFactoryProtocol? + + /// Creates a client for the gay.pizza.stable.diffusion.ModelService service. + /// + /// - Parameters: + /// - channel: `GRPCChannel` to the service host. + /// - defaultCallOptions: Options to use for each service call if the user doesn't provide them. + /// - interceptors: A factory providing interceptors for each RPC. + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdModelServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self.defaultCallOptions = defaultCallOptions + self.interceptors = interceptors + } +} + +#if compiler(>=5.6) +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public protocol SdModelServiceAsyncClientProtocol: GRPCClient { + static var serviceDescriptor: GRPCServiceDescriptor { get } + var interceptors: SdModelServiceClientInterceptorFactoryProtocol? { get } + + func makeListModelsCall( + _ request: SdListModelsRequest, + callOptions: CallOptions? + ) -> GRPCAsyncUnaryCall + + func makeReloadModelsCall( + _ request: SdReloadModelsRequest, + callOptions: CallOptions? + ) -> GRPCAsyncUnaryCall + + func makeLoadModelCall( + _ request: SdLoadModelRequest, + callOptions: CallOptions? + ) -> GRPCAsyncUnaryCall +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdModelServiceAsyncClientProtocol { + public static var serviceDescriptor: GRPCServiceDescriptor { + return SdModelServiceClientMetadata.serviceDescriptor + } + + public var interceptors: SdModelServiceClientInterceptorFactoryProtocol? { + return nil + } + + public func makeListModelsCall( + _ request: SdListModelsRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncUnaryCall { + return self.makeAsyncUnaryCall( + path: SdModelServiceClientMetadata.Methods.listModels.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeListModelsInterceptors() ?? [] + ) + } + + public func makeReloadModelsCall( + _ request: SdReloadModelsRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncUnaryCall { + return self.makeAsyncUnaryCall( + path: SdModelServiceClientMetadata.Methods.reloadModels.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [] + ) + } + + public func makeLoadModelCall( + _ request: SdLoadModelRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncUnaryCall { + return self.makeAsyncUnaryCall( + path: SdModelServiceClientMetadata.Methods.loadModel.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeLoadModelInterceptors() ?? [] + ) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdModelServiceAsyncClientProtocol { + public func listModels( + _ request: SdListModelsRequest, + callOptions: CallOptions? = nil + ) async throws -> SdListModelsResponse { + return try await self.performAsyncUnaryCall( + path: SdModelServiceClientMetadata.Methods.listModels.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeListModelsInterceptors() ?? [] + ) + } + + public func reloadModels( + _ request: SdReloadModelsRequest, + callOptions: CallOptions? = nil + ) async throws -> SdReloadModelsResponse { + return try await self.performAsyncUnaryCall( + path: SdModelServiceClientMetadata.Methods.reloadModels.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [] + ) + } + + public func loadModel( + _ request: SdLoadModelRequest, + callOptions: CallOptions? = nil + ) async throws -> SdLoadModelResponse { + return try await self.performAsyncUnaryCall( + path: SdModelServiceClientMetadata.Methods.loadModel.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeLoadModelInterceptors() ?? [] + ) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct SdModelServiceAsyncClient: SdModelServiceAsyncClientProtocol { + public var channel: GRPCChannel + public var defaultCallOptions: CallOptions + public var interceptors: SdModelServiceClientInterceptorFactoryProtocol? + + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdModelServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self.defaultCallOptions = defaultCallOptions + self.interceptors = interceptors + } +} + +#endif // compiler(>=5.6) + +public protocol SdModelServiceClientInterceptorFactoryProtocol: GRPCSendable { + + /// - Returns: Interceptors to use when invoking 'listModels'. + func makeListModelsInterceptors() -> [ClientInterceptor] + + /// - Returns: Interceptors to use when invoking 'reloadModels'. + func makeReloadModelsInterceptors() -> [ClientInterceptor] + + /// - Returns: Interceptors to use when invoking 'loadModel'. + func makeLoadModelInterceptors() -> [ClientInterceptor] +} + +public enum SdModelServiceClientMetadata { + public static let serviceDescriptor = GRPCServiceDescriptor( + name: "ModelService", + fullName: "gay.pizza.stable.diffusion.ModelService", + methods: [ + SdModelServiceClientMetadata.Methods.listModels, + SdModelServiceClientMetadata.Methods.reloadModels, + SdModelServiceClientMetadata.Methods.loadModel, + ] + ) + + public enum Methods { + public static let listModels = GRPCMethodDescriptor( + name: "ListModels", + path: "/gay.pizza.stable.diffusion.ModelService/ListModels", + type: GRPCCallType.unary + ) + + public static let reloadModels = GRPCMethodDescriptor( + name: "ReloadModels", + path: "/gay.pizza.stable.diffusion.ModelService/ReloadModels", + type: GRPCCallType.unary + ) + + public static let loadModel = GRPCMethodDescriptor( + name: "LoadModel", + path: "/gay.pizza.stable.diffusion.ModelService/LoadModel", + type: GRPCCallType.unary + ) + } +} + +/// Usage: instantiate `SdImageGenerationServiceClient`, then call methods of this protocol to make API calls. +public protocol SdImageGenerationServiceClientProtocol: GRPCClient { + var serviceName: String { get } + var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get } + + func generateImage( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? + ) -> UnaryCall +} + +extension SdImageGenerationServiceClientProtocol { + public var serviceName: String { + return "gay.pizza.stable.diffusion.ImageGenerationService" + } + + /// Unary call to GenerateImage + /// + /// - Parameters: + /// - request: Request to send to GenerateImage. + /// - callOptions: Call options. + /// - Returns: A `UnaryCall` with futures for the metadata, status and response. + public func generateImage( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? = nil + ) -> UnaryCall { + return self.makeUnaryCall( + path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [] + ) + } +} + +#if compiler(>=5.6) +@available(*, deprecated) +extension SdImageGenerationServiceClient: @unchecked Sendable {} +#endif // compiler(>=5.6) + +@available(*, deprecated, renamed: "SdImageGenerationServiceNIOClient") +public final class SdImageGenerationServiceClient: SdImageGenerationServiceClientProtocol { + private let lock = Lock() + private var _defaultCallOptions: CallOptions + private var _interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? + public let channel: GRPCChannel + public var defaultCallOptions: CallOptions { + get { self.lock.withLock { return self._defaultCallOptions } } + set { self.lock.withLockVoid { self._defaultCallOptions = newValue } } + } + public var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { + get { self.lock.withLock { return self._interceptors } } + set { self.lock.withLockVoid { self._interceptors = newValue } } + } + + /// Creates a client for the gay.pizza.stable.diffusion.ImageGenerationService service. + /// + /// - Parameters: + /// - channel: `GRPCChannel` to the service host. + /// - defaultCallOptions: Options to use for each service call if the user doesn't provide them. + /// - interceptors: A factory providing interceptors for each RPC. + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self._defaultCallOptions = defaultCallOptions + self._interceptors = interceptors + } +} + +public struct SdImageGenerationServiceNIOClient: SdImageGenerationServiceClientProtocol { + public var channel: GRPCChannel + public var defaultCallOptions: CallOptions + public var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? + + /// Creates a client for the gay.pizza.stable.diffusion.ImageGenerationService service. + /// + /// - Parameters: + /// - channel: `GRPCChannel` to the service host. + /// - defaultCallOptions: Options to use for each service call if the user doesn't provide them. + /// - interceptors: A factory providing interceptors for each RPC. + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self.defaultCallOptions = defaultCallOptions + self.interceptors = interceptors + } +} + +#if compiler(>=5.6) +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient { + static var serviceDescriptor: GRPCServiceDescriptor { get } + var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get } + + func makeGenerateImageCall( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? + ) -> GRPCAsyncUnaryCall +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdImageGenerationServiceAsyncClientProtocol { + public static var serviceDescriptor: GRPCServiceDescriptor { + return SdImageGenerationServiceClientMetadata.serviceDescriptor + } + + public var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { + return nil + } + + public func makeGenerateImageCall( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncUnaryCall { + return self.makeAsyncUnaryCall( + path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [] + ) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdImageGenerationServiceAsyncClientProtocol { + public func generateImage( + _ request: SdGenerateImagesRequest, + callOptions: CallOptions? = nil + ) async throws -> SdGenerateImagesResponse { + return try await self.performAsyncUnaryCall( + path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [] + ) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct SdImageGenerationServiceAsyncClient: SdImageGenerationServiceAsyncClientProtocol { + public var channel: GRPCChannel + public var defaultCallOptions: CallOptions + public var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? + + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self.defaultCallOptions = defaultCallOptions + self.interceptors = interceptors + } +} + +#endif // compiler(>=5.6) + +public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSendable { + + /// - Returns: Interceptors to use when invoking 'generateImage'. + func makeGenerateImageInterceptors() -> [ClientInterceptor] +} + +public enum SdImageGenerationServiceClientMetadata { + public static let serviceDescriptor = GRPCServiceDescriptor( + name: "ImageGenerationService", + fullName: "gay.pizza.stable.diffusion.ImageGenerationService", + methods: [ + SdImageGenerationServiceClientMetadata.Methods.generateImage, + ] + ) + + public enum Methods { + public static let generateImage = GRPCMethodDescriptor( + name: "GenerateImage", + path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImage", + type: GRPCCallType.unary + ) + } +} + +/// To build a server, implement a class that conforms to this protocol. +public protocol SdModelServiceProvider: CallHandlerProvider { + var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get } + + func listModels(request: SdListModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture + + func reloadModels(request: SdReloadModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture + + func loadModel(request: SdLoadModelRequest, context: StatusOnlyCallContext) -> EventLoopFuture +} + +extension SdModelServiceProvider { + public var serviceName: Substring { + return SdModelServiceServerMetadata.serviceDescriptor.fullName[...] + } + + /// Determines, calls and returns the appropriate request handler, depending on the request's method. + /// Returns nil for methods not handled by this service. + public func handle( + method name: Substring, + context: CallHandlerContext + ) -> GRPCServerHandlerProtocol? { + switch name { + case "ListModels": + return UnaryServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeListModelsInterceptors() ?? [], + userFunction: self.listModels(request:context:) + ) + + case "ReloadModels": + return UnaryServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [], + userFunction: self.reloadModels(request:context:) + ) + + case "LoadModel": + return UnaryServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeLoadModelInterceptors() ?? [], + userFunction: self.loadModel(request:context:) + ) + + default: + return nil + } + } +} + +#if compiler(>=5.6) + +/// To implement a server, implement an object which conforms to this protocol. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public protocol SdModelServiceAsyncProvider: CallHandlerProvider { + static var serviceDescriptor: GRPCServiceDescriptor { get } + var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get } + + @Sendable func listModels( + request: SdListModelsRequest, + context: GRPCAsyncServerCallContext + ) async throws -> SdListModelsResponse + + @Sendable func reloadModels( + request: SdReloadModelsRequest, + context: GRPCAsyncServerCallContext + ) async throws -> SdReloadModelsResponse + + @Sendable func loadModel( + request: SdLoadModelRequest, + context: GRPCAsyncServerCallContext + ) async throws -> SdLoadModelResponse +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdModelServiceAsyncProvider { + public static var serviceDescriptor: GRPCServiceDescriptor { + return SdModelServiceServerMetadata.serviceDescriptor + } + + public var serviceName: Substring { + return SdModelServiceServerMetadata.serviceDescriptor.fullName[...] + } + + public var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { + return nil + } + + public func handle( + method name: Substring, + context: CallHandlerContext + ) -> GRPCServerHandlerProtocol? { + switch name { + case "ListModels": + return GRPCAsyncServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeListModelsInterceptors() ?? [], + wrapping: self.listModels(request:context:) + ) + + case "ReloadModels": + return GRPCAsyncServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [], + wrapping: self.reloadModels(request:context:) + ) + + case "LoadModel": + return GRPCAsyncServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeLoadModelInterceptors() ?? [], + wrapping: self.loadModel(request:context:) + ) + + default: + return nil + } + } +} + +#endif // compiler(>=5.6) + +public protocol SdModelServiceServerInterceptorFactoryProtocol { + + /// - Returns: Interceptors to use when handling 'listModels'. + /// Defaults to calling `self.makeInterceptors()`. + func makeListModelsInterceptors() -> [ServerInterceptor] + + /// - Returns: Interceptors to use when handling 'reloadModels'. + /// Defaults to calling `self.makeInterceptors()`. + func makeReloadModelsInterceptors() -> [ServerInterceptor] + + /// - Returns: Interceptors to use when handling 'loadModel'. + /// Defaults to calling `self.makeInterceptors()`. + func makeLoadModelInterceptors() -> [ServerInterceptor] +} + +public enum SdModelServiceServerMetadata { + public static let serviceDescriptor = GRPCServiceDescriptor( + name: "ModelService", + fullName: "gay.pizza.stable.diffusion.ModelService", + methods: [ + SdModelServiceServerMetadata.Methods.listModels, + SdModelServiceServerMetadata.Methods.reloadModels, + SdModelServiceServerMetadata.Methods.loadModel, + ] + ) + + public enum Methods { + public static let listModels = GRPCMethodDescriptor( + name: "ListModels", + path: "/gay.pizza.stable.diffusion.ModelService/ListModels", + type: GRPCCallType.unary + ) + + public static let reloadModels = GRPCMethodDescriptor( + name: "ReloadModels", + path: "/gay.pizza.stable.diffusion.ModelService/ReloadModels", + type: GRPCCallType.unary + ) + + public static let loadModel = GRPCMethodDescriptor( + name: "LoadModel", + path: "/gay.pizza.stable.diffusion.ModelService/LoadModel", + type: GRPCCallType.unary + ) + } +} +/// To build a server, implement a class that conforms to this protocol. +public protocol SdImageGenerationServiceProvider: CallHandlerProvider { + var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get } + + func generateImage(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture +} + +extension SdImageGenerationServiceProvider { + public var serviceName: Substring { + return SdImageGenerationServiceServerMetadata.serviceDescriptor.fullName[...] + } + + /// Determines, calls and returns the appropriate request handler, depending on the request's method. + /// Returns nil for methods not handled by this service. + public func handle( + method name: Substring, + context: CallHandlerContext + ) -> GRPCServerHandlerProtocol? { + switch name { + case "GenerateImage": + return UnaryServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [], + userFunction: self.generateImage(request:context:) + ) + + default: + return nil + } + } +} + +#if compiler(>=5.6) + +/// To implement a server, implement an object which conforms to this protocol. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider { + static var serviceDescriptor: GRPCServiceDescriptor { get } + var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get } + + @Sendable func generateImage( + request: SdGenerateImagesRequest, + context: GRPCAsyncServerCallContext + ) async throws -> SdGenerateImagesResponse +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdImageGenerationServiceAsyncProvider { + public static var serviceDescriptor: GRPCServiceDescriptor { + return SdImageGenerationServiceServerMetadata.serviceDescriptor + } + + public var serviceName: Substring { + return SdImageGenerationServiceServerMetadata.serviceDescriptor.fullName[...] + } + + public var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { + return nil + } + + public func handle( + method name: Substring, + context: CallHandlerContext + ) -> GRPCServerHandlerProtocol? { + switch name { + case "GenerateImage": + return GRPCAsyncServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer(), + responseSerializer: ProtobufSerializer(), + interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [], + wrapping: self.generateImage(request:context:) + ) + + default: + return nil + } + } +} + +#endif // compiler(>=5.6) + +public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol { + + /// - Returns: Interceptors to use when handling 'generateImage'. + /// Defaults to calling `self.makeInterceptors()`. + func makeGenerateImageInterceptors() -> [ServerInterceptor] +} + +public enum SdImageGenerationServiceServerMetadata { + public static let serviceDescriptor = GRPCServiceDescriptor( + name: "ImageGenerationService", + fullName: "gay.pizza.stable.diffusion.ImageGenerationService", + methods: [ + SdImageGenerationServiceServerMetadata.Methods.generateImage, + ] + ) + + public enum Methods { + public static let generateImage = GRPCMethodDescriptor( + name: "GenerateImage", + path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImage", + type: GRPCCallType.unary + ) + } +} diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift new file mode 100644 index 0000000..f56ec7b --- /dev/null +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -0,0 +1,572 @@ +// DO NOT EDIT. +// swift-format-ignore-file +// +// Generated by the Swift generator plugin for the protocol buffer compiler. +// Source: StableDiffusion.proto +// +// For information on using the generated types, please see the documentation: +// https://github.com/apple/swift-protobuf/ + +import Foundation +import SwiftProtobuf + +// If the compiler emits an error on this type, it is because this file +// was generated by a version of the `protoc` Swift plug-in that is +// incompatible with the version of SwiftProtobuf to which you are linking. +// Please ensure that you are building against the same version of the API +// that was used to generate this file. +fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck { + struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {} + typealias Version = _2 +} + +public enum SdScheduler: SwiftProtobuf.Enum { + public typealias RawValue = Int + case pndm // = 0 + case dpmSolverPlusPlus // = 1 + case UNRECOGNIZED(Int) + + public init() { + self = .pndm + } + + public init?(rawValue: Int) { + switch rawValue { + case 0: self = .pndm + case 1: self = .dpmSolverPlusPlus + default: self = .UNRECOGNIZED(rawValue) + } + } + + public var rawValue: Int { + switch self { + case .pndm: return 0 + case .dpmSolverPlusPlus: return 1 + case .UNRECOGNIZED(let i): return i + } + } + +} + +#if swift(>=4.2) + +extension SdScheduler: CaseIterable { + // The compiler won't synthesize support with the UNRECOGNIZED case. + public static var allCases: [SdScheduler] = [ + .pndm, + .dpmSolverPlusPlus, + ] +} + +#endif // swift(>=4.2) + +public enum SdComputeUnits: SwiftProtobuf.Enum { + public typealias RawValue = Int + case cpu // = 0 + case cpuAndGpu // = 1 + case all // = 2 + case cpuAndNeuralEngine // = 3 + case UNRECOGNIZED(Int) + + public init() { + self = .cpu + } + + public init?(rawValue: Int) { + switch rawValue { + case 0: self = .cpu + case 1: self = .cpuAndGpu + case 2: self = .all + case 3: self = .cpuAndNeuralEngine + default: self = .UNRECOGNIZED(rawValue) + } + } + + public var rawValue: Int { + switch self { + case .cpu: return 0 + case .cpuAndGpu: return 1 + case .all: return 2 + case .cpuAndNeuralEngine: return 3 + case .UNRECOGNIZED(let i): return i + } + } + +} + +#if swift(>=4.2) + +extension SdComputeUnits: CaseIterable { + // The compiler won't synthesize support with the UNRECOGNIZED case. + public static var allCases: [SdComputeUnits] = [ + .cpu, + .cpuAndGpu, + .all, + .cpuAndNeuralEngine, + ] +} + +#endif // swift(>=4.2) + +public struct SdModelInfo { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var name: String = String() + + public var attention: String = String() + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdImage { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var content: Data = Data() + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdListModelsRequest { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdListModelsResponse { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var models: [SdModelInfo] = [] + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdReloadModelsRequest { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdReloadModelsResponse { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdLoadModelRequest { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var modelName: String = String() + + public var computeUnits: SdComputeUnits = .cpu + + public var reduceMemory: Bool = false + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdLoadModelResponse { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdGenerateImagesRequest { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var modelName: String = String() + + public var imageCount: UInt32 = 0 + + public var prompt: String = String() + + public var negativePrompt: String = String() + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +public struct SdGenerateImagesResponse { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + public var images: [SdImage] = [] + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +#if swift(>=5.5) && canImport(_Concurrency) +extension SdScheduler: @unchecked Sendable {} +extension SdComputeUnits: @unchecked Sendable {} +extension SdModelInfo: @unchecked Sendable {} +extension SdImage: @unchecked Sendable {} +extension SdListModelsRequest: @unchecked Sendable {} +extension SdListModelsResponse: @unchecked Sendable {} +extension SdReloadModelsRequest: @unchecked Sendable {} +extension SdReloadModelsResponse: @unchecked Sendable {} +extension SdLoadModelRequest: @unchecked Sendable {} +extension SdLoadModelResponse: @unchecked Sendable {} +extension SdGenerateImagesRequest: @unchecked Sendable {} +extension SdGenerateImagesResponse: @unchecked Sendable {} +#endif // swift(>=5.5) && canImport(_Concurrency) + +// MARK: - Code below here is support for the SwiftProtobuf runtime. + +fileprivate let _protobuf_package = "gay.pizza.stable.diffusion" + +extension SdScheduler: SwiftProtobuf._ProtoNameProviding { + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 0: .same(proto: "pndm"), + 1: .same(proto: "dpmSolverPlusPlus"), + ] +} + +extension SdComputeUnits: SwiftProtobuf._ProtoNameProviding { + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 0: .same(proto: "cpu"), + 1: .same(proto: "cpu_and_gpu"), + 2: .same(proto: "all"), + 3: .same(proto: "cpu_and_neural_engine"), + ] +} + +extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".ModelInfo" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .same(proto: "name"), + 2: .same(proto: "attention"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularStringField(value: &self.name) }() + case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.name.isEmpty { + try visitor.visitSingularStringField(value: self.name, fieldNumber: 1) + } + if !self.attention.isEmpty { + try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdModelInfo, rhs: SdModelInfo) -> Bool { + if lhs.name != rhs.name {return false} + if lhs.attention != rhs.attention {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdImage: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".Image" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .same(proto: "content"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularBytesField(value: &self.content) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.content.isEmpty { + try visitor.visitSingularBytesField(value: self.content, fieldNumber: 1) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdImage, rhs: SdImage) -> Bool { + if lhs.content != rhs.content {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdListModelsRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".ListModelsRequest" + public static let _protobuf_nameMap = SwiftProtobuf._NameMap() + + public mutating func decodeMessage(decoder: inout D) throws { + while let _ = try decoder.nextFieldNumber() { + } + } + + public func traverse(visitor: inout V) throws { + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdListModelsRequest, rhs: SdListModelsRequest) -> Bool { + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdListModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".ListModelsResponse" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .same(proto: "models"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeRepeatedMessageField(value: &self.models) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.models.isEmpty { + try visitor.visitRepeatedMessageField(value: self.models, fieldNumber: 1) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdListModelsResponse, rhs: SdListModelsResponse) -> Bool { + if lhs.models != rhs.models {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdReloadModelsRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".ReloadModelsRequest" + public static let _protobuf_nameMap = SwiftProtobuf._NameMap() + + public mutating func decodeMessage(decoder: inout D) throws { + while let _ = try decoder.nextFieldNumber() { + } + } + + public func traverse(visitor: inout V) throws { + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdReloadModelsRequest, rhs: SdReloadModelsRequest) -> Bool { + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdReloadModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".ReloadModelsResponse" + public static let _protobuf_nameMap = SwiftProtobuf._NameMap() + + public mutating func decodeMessage(decoder: inout D) throws { + while let _ = try decoder.nextFieldNumber() { + } + } + + public func traverse(visitor: inout V) throws { + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdReloadModelsResponse, rhs: SdReloadModelsResponse) -> Bool { + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".LoadModelRequest" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "model_name"), + 2: .standard(proto: "compute_units"), + 3: .standard(proto: "reduce_memory"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }() + case 2: try { try decoder.decodeSingularEnumField(value: &self.computeUnits) }() + case 3: try { try decoder.decodeSingularBoolField(value: &self.reduceMemory) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.modelName.isEmpty { + try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1) + } + if self.computeUnits != .cpu { + try visitor.visitSingularEnumField(value: self.computeUnits, fieldNumber: 2) + } + if self.reduceMemory != false { + try visitor.visitSingularBoolField(value: self.reduceMemory, fieldNumber: 3) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdLoadModelRequest, rhs: SdLoadModelRequest) -> Bool { + if lhs.modelName != rhs.modelName {return false} + if lhs.computeUnits != rhs.computeUnits {return false} + if lhs.reduceMemory != rhs.reduceMemory {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdLoadModelResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".LoadModelResponse" + public static let _protobuf_nameMap = SwiftProtobuf._NameMap() + + public mutating func decodeMessage(decoder: inout D) throws { + while let _ = try decoder.nextFieldNumber() { + } + } + + public func traverse(visitor: inout V) throws { + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdLoadModelResponse, rhs: SdLoadModelResponse) -> Bool { + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".GenerateImagesRequest" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "model_name"), + 2: .standard(proto: "image_count"), + 3: .same(proto: "prompt"), + 4: .standard(proto: "negative_prompt"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }() + case 2: try { try decoder.decodeSingularUInt32Field(value: &self.imageCount) }() + case 3: try { try decoder.decodeSingularStringField(value: &self.prompt) }() + case 4: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.modelName.isEmpty { + try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1) + } + if self.imageCount != 0 { + try visitor.visitSingularUInt32Field(value: self.imageCount, fieldNumber: 2) + } + if !self.prompt.isEmpty { + try visitor.visitSingularStringField(value: self.prompt, fieldNumber: 3) + } + if !self.negativePrompt.isEmpty { + try visitor.visitSingularStringField(value: self.negativePrompt, fieldNumber: 4) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdGenerateImagesRequest, rhs: SdGenerateImagesRequest) -> Bool { + if lhs.modelName != rhs.modelName {return false} + if lhs.imageCount != rhs.imageCount {return false} + if lhs.prompt != rhs.prompt {return false} + if lhs.negativePrompt != rhs.negativePrompt {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".GenerateImagesResponse" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .same(proto: "images"), + ] + + public mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }() + default: break + } + } + } + + public func traverse(visitor: inout V) throws { + if !self.images.isEmpty { + try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdGenerateImagesResponse, rhs: SdGenerateImagesResponse) -> Bool { + if lhs.images != rhs.images {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} diff --git a/Sources/StableDiffusionServer/Errors.swift b/Sources/StableDiffusionServer/Errors.swift new file mode 100644 index 0000000..7fd124e --- /dev/null +++ b/Sources/StableDiffusionServer/Errors.swift @@ -0,0 +1,7 @@ +import Foundation + +enum SdServerError: Error { + case modelNotLoaded + case imageEncode + case modelNotFound +} diff --git a/Sources/StableDiffusionServer/ImageExtensions.swift b/Sources/StableDiffusionServer/ImageExtensions.swift new file mode 100644 index 0000000..fc9ea97 --- /dev/null +++ b/Sources/StableDiffusionServer/ImageExtensions.swift @@ -0,0 +1,22 @@ +import Foundation +import CoreImage +import UniformTypeIdentifiers + +extension CGImage { + func toPngData() throws -> Data { + guard let data = CFDataCreateMutable(nil, 0) else { + throw SdServerError.imageEncode + } + + guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else { + throw SdServerError.imageEncode + } + + CGImageDestinationAddImage(destination, self, nil) + if CGImageDestinationFinalize(destination) { + return data as Data + } else { + throw SdServerError.imageEncode + } + } +} diff --git a/Sources/StableDiffusionServer/ImageGenerationService.swift b/Sources/StableDiffusionServer/ImageGenerationService.swift new file mode 100644 index 0000000..6a02b5c --- /dev/null +++ b/Sources/StableDiffusionServer/ImageGenerationService.swift @@ -0,0 +1,18 @@ +import Foundation +import GRPC +import StableDiffusionProtos + +class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider { + private let modelManager: ModelManager + + init(modelManager: ModelManager) { + self.modelManager = modelManager + } + + func generateImage(request: SdGenerateImagesRequest, context: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse { + guard let state = await modelManager.getModelState(name: request.modelName) else { + throw SdServerError.modelNotFound + } + return try await state.generate(request) + } +} diff --git a/Sources/StableDiffusionServer/ModelManager.swift b/Sources/StableDiffusionServer/ModelManager.swift new file mode 100644 index 0000000..c6d1b37 --- /dev/null +++ b/Sources/StableDiffusionServer/ModelManager.swift @@ -0,0 +1,66 @@ +import Foundation +import StableDiffusion +import StableDiffusionProtos + +actor ModelManager { + private var modelInfos: [String : SdModelInfo] = [:] + private var modelUrls: [String : URL] = [:] + private var modelStates: [String : ModelState] = [:] + + private let modelBaseURL: URL + + public init(modelBaseURL: URL) { + self.modelBaseURL = modelBaseURL + } + + func reloadModels() throws { + modelInfos.removeAll() + modelStates.removeAll() + let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey]) + for subdirectoryURL in contents { + let values = try subdirectoryURL.resourceValues(forKeys: [.isDirectoryKey]) + if values.isDirectory ?? false { + try addModel(url: subdirectoryURL) + } + } + } + + func listModels() -> [SdModelInfo] { + return Array(modelInfos.values) + } + + func getModelState(name: String) -> ModelState? { + return modelStates[name] + } + + private func addModel(url: URL) throws { + var info = SdModelInfo() + info.name = url.lastPathComponent + let attention = getModelAttention(url) + info.attention = attention ?? "unknown" + modelInfos[info.name] = info + modelUrls[info.name] = url + modelStates[info.name] = try ModelState(url: url) + } + + private func getModelAttention(_ url: URL) -> String? { + let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json") + + struct ModelMetadata: Decodable { + let mlProgramOperationTypeHistogram: [String: Int] + } + + do { + let jsonData = try Data(contentsOf: unetMetadataURL) + let metadatas = try JSONDecoder().decode([ModelMetadata].self, from: jsonData) + + guard metadatas.count == 1 else { + return nil + } + + return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original" + } catch { + return nil + } + } +} diff --git a/Sources/StableDiffusionServer/ModelService.swift b/Sources/StableDiffusionServer/ModelService.swift new file mode 100644 index 0000000..140ca44 --- /dev/null +++ b/Sources/StableDiffusionServer/ModelService.swift @@ -0,0 +1,31 @@ +import Foundation +import GRPC +import StableDiffusionProtos + +class ModelServiceProvider: SdModelServiceAsyncProvider { + private let modelManager: ModelManager + + init(modelManager: ModelManager) { + self.modelManager = modelManager + } + + func listModels(request: SdListModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse { + let models = await modelManager.listModels() + var response = SdListModelsResponse() + response.models.append(contentsOf: models) + return response + } + + func reloadModels(request: SdReloadModelsRequest, context: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse { + try await modelManager.reloadModels() + return SdReloadModelsResponse() + } + + func loadModel(request: SdLoadModelRequest, context: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse { + guard let state = await modelManager.getModelState(name: request.modelName) else { + throw SdServerError.modelNotFound + } + try await state.load() + return SdLoadModelResponse() + } +} diff --git a/Sources/StableDiffusionServer/ModelState.swift b/Sources/StableDiffusionServer/ModelState.swift new file mode 100644 index 0000000..6d02421 --- /dev/null +++ b/Sources/StableDiffusionServer/ModelState.swift @@ -0,0 +1,52 @@ +import Foundation +import StableDiffusionProtos +import StableDiffusion +import CoreML + +actor ModelState { + private let url: URL + private var pipeline: StableDiffusionPipeline? = nil + private var tokenizer: BPETokenizer? = nil + + init(url: URL) throws { + self.url = url + } + + func load() throws { + let config = MLModelConfiguration() + config.computeUnits = .all + pipeline = try StableDiffusionPipeline( + resourcesAt: url, + controlNet: [], + configuration: config, + disableSafety: true, + reduceMemory: false + ) + let mergesUrl = url.appending(component: "merges.txt") + let vocabUrl = url.appending(component: "vocab.json") + tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl) + } + + func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse { + guard let pipeline else { + throw SdServerError.modelNotLoaded + } + + var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt) + pipelineConfig.negativePrompt = request.negativePrompt + pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max) + + var response = SdGenerateImagesResponse() + for _ in 0 ..< request.imageCount { + let images = try pipeline.generateImages(configuration: pipelineConfig) + + for cgImage in images { + guard let cgImage else { continue } + var image = SdImage() + image.content = try cgImage.toPngData() + response.images.append(image) + } + } + return response + } +} diff --git a/Sources/StableDiffusionServer/main.swift b/Sources/StableDiffusionServer/main.swift new file mode 100644 index 0000000..9c7c219 --- /dev/null +++ b/Sources/StableDiffusionServer/main.swift @@ -0,0 +1,40 @@ +import Foundation +import ArgumentParser +import GRPC +import NIO +import System + +struct ServerCommand: ParsableCommand { + @Option(name: .shortAndLong, help: "Path to models directory") + var modelsDirectoryPath: String = "models" + + mutating func run() throws { + let modelsDirectoryURL = URL(filePath: modelsDirectoryPath) + let modelManager = ModelManager(modelBaseURL: modelsDirectoryURL) + + let semaphore = DispatchSemaphore(value: 0) + Task { + print("Loading initial models...") + do { + try await modelManager.reloadModels() + } catch { + ServerCommand.exit(withError: error) + } + print("Loaded initial models.") + semaphore.signal() + } + semaphore.wait() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) + _ = Server.insecure(group: group) + .withServiceProviders([ + ModelServiceProvider(modelManager: modelManager), + ImageGenerationServiceProvider(modelManager: modelManager) + ]) + .bind(host: "0.0.0.0", port: 4546) + + dispatchMain() + } + +} +ServerCommand.main() diff --git a/Sources/TestStableDiffusionClient/main.swift b/Sources/TestStableDiffusionClient/main.swift new file mode 100644 index 0000000..8093c0b --- /dev/null +++ b/Sources/TestStableDiffusionClient/main.swift @@ -0,0 +1,46 @@ +import Foundation +import StableDiffusionProtos +import NIO +import System +import GRPC + +let group = PlatformSupport.makeEventLoopGroup(loopCount: 1) +defer { + try? group.syncShutdownGracefully() +} + +let channel = try GRPCChannelPool.with( + target: .host("localhost", port: 4546), + transportSecurity: .plaintext, + eventLoopGroup: group +) + +let modelService = SdModelServiceAsyncClient(channel: channel) +let imageGeneratorService = SdImageGenerationServiceAsyncClient(channel: channel) + +Task { @MainActor in + do { + let modelListResponse = try await modelService.listModels(.init()) + print("Loading model...") + let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }! + _ = try await modelService.loadModel(.with { request in + request.modelName = modelInfo.name + }) + print("Loaded model.") + + print("Generating image...") + let request = SdGenerateImagesRequest.with { + $0.modelName = modelInfo.name + $0.prompt = "cat" + $0.imageCount = 1 + } + + let response = try await imageGeneratorService.generateImage(request) + print("Generated image.") + print(response) + } catch { + print(error) + exit(1) + } +} +dispatchMain() diff --git a/tools/generate-protobuf-swift.sh b/tools/generate-protobuf-swift.sh new file mode 100755 index 0000000..4c70499 --- /dev/null +++ b/tools/generate-protobuf-swift.sh @@ -0,0 +1,6 @@ +#!/bin/sh +set -e + +cd "$(dirname "${0}")/../Common" + +exec protoc --swift_opt=Visibility=Public --grpc-swift_opt=Visibility=Public --swift_out=../Sources/StableDiffusionProtos --grpc-swift_out=../Sources/StableDiffusionProtos StableDiffusion.proto