diff --git a/Clients/Java/build.gradle.kts b/Clients/Java/build.gradle.kts index 7df1a35..efafde7 100644 --- a/Clients/Java/build.gradle.kts +++ b/Clients/Java/build.gradle.kts @@ -26,15 +26,6 @@ java { withSourcesJar() } -sourceSets { - main { - proto { - srcDir("../../Common") - include("*.proto") - } - } -} - dependencies { implementation("org.jetbrains.kotlin:kotlin-bom") implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8") diff --git a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt index 30c5c42..844d8c5 100644 --- a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt +++ b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt @@ -1,10 +1,13 @@ package gay.pizza.stable.diffusion.sample +import gay.pizza.stable.diffusion.StableDiffusion import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest import gay.pizza.stable.diffusion.StableDiffusionRpcClient import io.grpc.ManagedChannelBuilder +import kotlin.io.path.Path +import kotlin.io.path.writeBytes import kotlin.system.exitProcess fun main() { @@ -15,20 +18,22 @@ fun main() { val client = StableDiffusionRpcClient(channel) val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance()) - if (modelListResponse.modelsList.isEmpty()) { + if (modelListResponse.availableModelsList.isEmpty()) { println("no available models") exitProcess(0) } println("available models:") - for (model in modelListResponse.modelsList) { - println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}") + for (model in modelListResponse.availableModelsList) { + val maybeLoadedComputeUnits = if (model.isLoaded) " loaded_compute_units=${model.loadedComputeUnits.name}" else "" + println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}") } - val model = modelListResponse.modelsList.random() + val model = modelListResponse.availableModelsList.random() if (!model.isLoaded) { println("loading model ${model.name}...") client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply { modelName = model.name + computeUnits = model.supportedComputeUnitsList.first() }.build()) } else { println("using model ${model.name}...") @@ -36,14 +41,22 @@ fun main() { println("generating images...") - val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply { + val request = GenerateImagesRequest.newBuilder().apply { modelName = model.name - imageCount = 1 + outputImageFormat = StableDiffusion.ImageFormat.png + batchSize = 2 + batchCount = 2 prompt = "cat" negativePrompt = "bad, low quality, nsfw" - }.build()) + }.build() + val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request) - println("generated ${generateImagesResponse.imagesCount} images") + println("generated ${generateImagesResponse.imagesCount} images:") + for ((index, image) in generateImagesResponse.imagesList.withIndex()) { + println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)") + val path = Path("work/image${index}.${image.format.name}") + path.writeBytes(image.data.toByteArray()) + } channel.shutdownNow() } diff --git a/Clients/Java/src/main/proto b/Clients/Java/src/main/proto new file mode 120000 index 0000000..f332154 --- /dev/null +++ b/Clients/Java/src/main/proto @@ -0,0 +1 @@ +../../../../Common \ No newline at end of file diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index 258d9e9..0b0d10b 100644 --- a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -1,63 +1,247 @@ +/** + * Stable Diffusion RPC service for Apple Platforms. + */ syntax = "proto3"; package gay.pizza.stable.diffusion; +/** + * Utilize a prefix of 'Sd' for Swift. + */ option swift_prefix = "Sd"; -message ModelInfo { - string name = 1; - string attention = 2; - bool is_loaded = 3; +/** + * Represents the model attention. Model attention has to do with how the model is encoded, and + * can determine what compute units are able to support a particular model. + */ +enum ModelAttention { + /** + * The model is an original attention type. It can be loaded only onto CPU & GPU compute units. + */ + original = 0; + + /** + * The model is a split-ein-sum attention type. It can be loaded onto all compute units, + * including the Apple Neural Engine. + */ + split_ein_sum = 1; } -message Image { - bytes content = 1; -} - -message ListModelsRequest {} -message ListModelsResponse { - repeated ModelInfo models = 1; -} - -message ReloadModelsRequest {} -message ReloadModelsResponse {} - +/** + * Represents the schedulers that are used to sample images. + */ enum Scheduler { + /** + * The PNDM (Pseudo numerical methods for diffusion models) scheduler. + */ pndm = 0; - dpmSolverPlusPlus = 1; + + /** + * The DPM-Solver++ scheduler. + */ + dpm_solver_plus_plus = 1; } +/** + * Represents a specifier for what compute units are available for ML tasks. + */ enum ComputeUnits { + /** + * The CPU as a singular compute unit. + */ cpu = 0; + + /** + * The CPU & GPU combined into a singular compute unit. + */ cpu_and_gpu = 1; + + /** + * Allow the usage of all compute units. CoreML will decided where the model is loaded. + */ all = 2; + + /** + * The CPU & Neural Engine combined into a singular compute unit. + */ cpu_and_neural_engine = 3; } -message LoadModelRequest { - string model_name = 1; - ComputeUnits compute_units = 2; - bool reduce_memory = 3; +/** + * Represents information about an available model. + * The primary key of a model is it's 'name' field. + */ +message ModelInfo { + /** + * The name of the available model. Note that within the context of a single RPC server, + * the name of a model is a unique identifier. This may not be true when utilizing a cluster or + * load balanced server, so keep that in mind. + */ + string name = 1; + + /** + * The attention of the model. Model attention determines what compute units can be used to + * load the model and make predictions. + */ + ModelAttention attention = 2; + + /** + * Whether the model is currently loaded onto an available compute unit. + */ + bool is_loaded = 3; + + /** + * The compute unit that the model is currently loaded into, if it is loaded to one at all. + * When is_loaded is false, the value of this field should be null. + */ + ComputeUnits loaded_compute_units = 4; + + /** + * The compute units that this model supports using. + */ + repeated ComputeUnits supported_compute_units = 5; } +/** + * Represents the format of an image. + */ +enum ImageFormat { + /** + * The PNG image format. + */ + png = 0; +} + +/** + * Represents an image within the Stable Diffusion context. + * This could be an input image for an image generation request, or it could be + * a generated image from the Stable Diffusion model. + */ +message Image { + /** + * The format of the image. + */ + ImageFormat format = 1; + + /** + * The raw data of the image, in the specified format. + */ + bytes data = 2; +} + +/** + * Represents a request to list the models available on the host. + */ +message ListModelsRequest {} + +/** + * Represents a response to listing the models available on the host. + */ +message ListModelsResponse { + /** + * The available models on the Stable Diffusion server. + */ + repeated ModelInfo available_models = 1; +} + +/** + * Represents a request to load a model into a specified compute unit. + */ +message LoadModelRequest { + /** + * The model name to load onto the compute unit. + */ + string model_name = 1; + + /** + * The compute units to load the model onto. + */ + ComputeUnits compute_units = 2; +} + +/** + * Represents a response to loading a model. + */ message LoadModelResponse {} +/** + * The model service, for management and loading of models. + */ service ModelService { + /** + * Lists the available models on the host. + * This will return both models that are currently loaded, and models that are not yet loaded. + */ rpc ListModels(ListModelsRequest) returns (ListModelsResponse); - rpc ReloadModels(ReloadModelsRequest) returns (ReloadModelsResponse); + + /** + * Loads a model onto a compute unit. + */ rpc LoadModel(LoadModelRequest) returns (LoadModelResponse); } +/** + * Represents a request to generate images using a loaded model. + */ message GenerateImagesRequest { + /** + * The model name to use for generation. + * The model must be already be loaded using ModelService.LoadModel RPC method. + */ string model_name = 1; - uint32 image_count = 2; - string prompt = 3; - string negative_prompt = 4; + + /** + * The output format for generated images. + */ + ImageFormat output_image_format = 2; + + /** + * The number of batches of images to generate. + */ + uint32 batch_count = 3; + + /** + * The number of images inside a single batch. + */ + uint32 batch_size = 4; + + /** + * The positive textual prompt for image generation. + */ + string prompt = 5; + + /** + * The negative prompt for image generation. + */ + string negative_prompt = 6; + + /** + * The random seed to use. + * Zero indicates that the seed should be random. + */ + uint32 seed = 7; } +/** + * Represents the response from image generation. + */ message GenerateImagesResponse { + /** + * The set of generated images by the Stable Diffusion pipeline. + */ repeated Image images = 1; + + /** + * The seeds that were used to generate the images. + */ + repeated uint32 seeds = 2; } +/** + * The image generation service, for generating images from loaded models. + */ service ImageGenerationService { - rpc GenerateImage(GenerateImagesRequest) returns (GenerateImagesResponse); + /** + * Generates images using a loaded model. + */ + rpc GenerateImages(GenerateImagesRequest) returns (GenerateImagesResponse); } diff --git a/Sources/StableDiffusionControl/main.swift b/Sources/StableDiffusionControl/main.swift index 34c9098..eb8801c 100644 --- a/Sources/StableDiffusionControl/main.swift +++ b/Sources/StableDiffusionControl/main.swift @@ -9,23 +9,25 @@ let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port Task { @MainActor in do { let modelListResponse = try await client.modelService.listModels(.init()) - print("Loading model...") - let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }! + print("Loading random model...") + let modelInfo = modelListResponse.availableModels.randomElement()! _ = try await client.modelService.loadModel(.with { request in request.modelName = modelInfo.name }) - print("Loaded model.") + print("Loaded random model.") print("Generating image...") let request = SdGenerateImagesRequest.with { $0.modelName = modelInfo.name + $0.outputImageFormat = .png $0.prompt = "cat" - $0.imageCount = 1 + $0.batchCount = 1 + $0.batchSize = 1 } - let response = try await client.imageGenerationService.generateImage(request) + let response = try await client.imageGenerationService.generateImages(request) let image = response.images.first! - try image.content.write(to: URL(filePath: "output.png")) + try image.data.write(to: URL(filePath: "output.png")) print("Generated image to output.png") exit(0) } catch { diff --git a/Sources/StableDiffusionCore/Errors.swift b/Sources/StableDiffusionCore/Errors.swift index 5868f76..6043444 100644 --- a/Sources/StableDiffusionCore/Errors.swift +++ b/Sources/StableDiffusionCore/Errors.swift @@ -2,6 +2,6 @@ import Foundation public enum SdCoreError: Error { case modelNotLoaded - case imageEncode + case imageEncodeFailed case modelNotFound } diff --git a/Sources/StableDiffusionCore/ImageExtensions.swift b/Sources/StableDiffusionCore/ImageExtensions.swift index d242107..fa70e0d 100644 --- a/Sources/StableDiffusionCore/ImageExtensions.swift +++ b/Sources/StableDiffusionCore/ImageExtensions.swift @@ -1,22 +1,38 @@ import CoreImage import Foundation +import StableDiffusionProtos import UniformTypeIdentifiers extension CGImage { - func toPngData() throws -> Data { + func toImageData(format: SdImageFormat) throws -> Data { guard let data = CFDataCreateMutable(nil, 0) else { - throw SdCoreError.imageEncode + throw SdCoreError.imageEncodeFailed } - guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else { - throw SdCoreError.imageEncode + guard let destination = try CGImageDestinationCreateWithData(data, formatToTypeIdentifier(format) as CFString, 1, nil) else { + throw SdCoreError.imageEncodeFailed } CGImageDestinationAddImage(destination, self, nil) if CGImageDestinationFinalize(destination) { return data as Data } else { - throw SdCoreError.imageEncode + throw SdCoreError.imageEncodeFailed + } + } + + func toSdImage(format: SdImageFormat) throws -> SdImage { + let content = try toImageData(format: format) + var image = SdImage() + image.format = format + image.data = content + return image + } + + private func formatToTypeIdentifier(_ format: SdImageFormat) throws -> String { + switch format { + case .png: return "public.png" + default: throw SdCoreError.imageEncodeFailed } } } diff --git a/Sources/StableDiffusionCore/ModelManager.swift b/Sources/StableDiffusionCore/ModelManager.swift index cfdf020..a046be3 100644 --- a/Sources/StableDiffusionCore/ModelManager.swift +++ b/Sources/StableDiffusionCore/ModelManager.swift @@ -13,7 +13,7 @@ public actor ModelManager { self.modelBaseURL = modelBaseURL } - public func reloadModels() throws { + public func reloadAvailableModels() throws { modelInfos.removeAll() modelUrls.removeAll() modelStates.removeAll() @@ -26,8 +26,37 @@ public actor ModelManager { } } - public func listModels() -> [SdModelInfo] { - Array(modelInfos.values) + public func listAvailableModels() async throws -> [SdModelInfo] { + var results: [SdModelInfo] = [] + for simpleInfo in modelInfos.values { + var info = try SdModelInfo(jsonString: simpleInfo.jsonString()) + if let maybeLoaded = modelStates[info.name] { + info.isLoaded = await maybeLoaded.isModelLoaded() + if let loadedComputeUnits = await maybeLoaded.loadedModelComputeUnits() { + info.loadedComputeUnits = loadedComputeUnits + } + } else { + info.isLoaded = false + info.loadedComputeUnits = .init() + } + + if info.attention == .splitEinSum { + info.supportedComputeUnits = [ + .cpuAndGpu, + .cpuAndNeuralEngine, + .cpu, + .all + ] + } else { + info.supportedComputeUnits = [ + .cpuAndGpu, + .cpu + ] + } + + results.append(info) + } + return results } public func createModelState(name: String) throws -> ModelState { @@ -53,13 +82,14 @@ public actor ModelManager { private func addModel(url: URL) throws { var info = SdModelInfo() info.name = url.lastPathComponent - let attention = getModelAttention(url) - info.attention = attention ?? "unknown" + if let attention = getModelAttention(url) { + info.attention = attention + } modelInfos[info.name] = info modelUrls[info.name] = url } - private func getModelAttention(_ url: URL) -> String? { + private func getModelAttention(_ url: URL) -> SdModelAttention? { let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json") struct ModelMetadata: Decodable { @@ -74,7 +104,7 @@ public actor ModelManager { return nil } - return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original" + return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? SdModelAttention.splitEinSum : SdModelAttention.original } catch { return nil } diff --git a/Sources/StableDiffusionCore/ModelState.swift b/Sources/StableDiffusionCore/ModelState.swift index 0d58c21..ba27bcf 100644 --- a/Sources/StableDiffusionCore/ModelState.swift +++ b/Sources/StableDiffusionCore/ModelState.swift @@ -7,14 +7,15 @@ public actor ModelState { private let url: URL private var pipeline: StableDiffusionPipeline? private var tokenizer: BPETokenizer? + private var loadedConfiguration: MLModelConfiguration? public init(url: URL) { self.url = url } - public func load() throws { + public func load(request: SdLoadModelRequest) throws { let config = MLModelConfiguration() - config.computeUnits = .cpuAndGPU + config.computeUnits = request.computeUnits.toMlComputeUnits() pipeline = try StableDiffusionPipeline( resourcesAt: url, controlNet: [], @@ -26,6 +27,15 @@ public actor ModelState { let vocabUrl = url.appending(component: "vocab.json") tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl) try pipeline?.loadResources() + loadedConfiguration = config + } + + public func isModelLoaded() -> Bool { + pipeline != nil + } + + public func loadedModelComputeUnits() -> SdComputeUnits? { + loadedConfiguration?.computeUnits.toSdComputeUnits() } public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse { @@ -33,20 +43,25 @@ public actor ModelState { throw SdCoreError.modelNotLoaded } + let baseSeed: UInt32 = request.seed + var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt) pipelineConfig.negativePrompt = request.negativePrompt - pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max) - + pipelineConfig.imageCount = Int(request.batchSize) var response = SdGenerateImagesResponse() - for _ in 0 ..< request.imageCount { + for _ in 0 ..< request.batchCount { + var seed = baseSeed + if seed == 0 { + seed = UInt32.random(in: 0 ..< UInt32.max) + } + pipelineConfig.seed = seed let images = try pipeline.generateImages(configuration: pipelineConfig) for cgImage in images { guard let cgImage else { continue } - var image = SdImage() - image.content = try cgImage.toPngData() - response.images.append(image) + try response.images.append(cgImage.toSdImage(format: request.outputImageFormat)) } + response.seeds.append(seed) } return response } diff --git a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift index 10880c5..e0eff21 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift @@ -26,6 +26,9 @@ import NIOConcurrencyHelpers import SwiftProtobuf +///* +/// The model service, for management and loading of models. +/// /// Usage: instantiate `SdModelServiceClient`, then call methods of this protocol to make API calls. public protocol SdModelServiceClientProtocol: GRPCClient { var serviceName: String { get } @@ -36,11 +39,6 @@ public protocol SdModelServiceClientProtocol: GRPCClient { callOptions: CallOptions? ) -> UnaryCall - func reloadModels( - _ request: SdReloadModelsRequest, - callOptions: CallOptions? - ) -> UnaryCall - func loadModel( _ request: SdLoadModelRequest, callOptions: CallOptions? @@ -52,7 +50,9 @@ extension SdModelServiceClientProtocol { return "gay.pizza.stable.diffusion.ModelService" } - /// Unary call to ListModels + ///* + /// Lists the available models on the host. + /// This will return both models that are currently loaded, and models that are not yet loaded. /// /// - Parameters: /// - request: Request to send to ListModels. @@ -70,25 +70,8 @@ extension SdModelServiceClientProtocol { ) } - /// Unary call to ReloadModels - /// - /// - Parameters: - /// - request: Request to send to ReloadModels. - /// - callOptions: Call options. - /// - Returns: A `UnaryCall` with futures for the metadata, status and response. - public func reloadModels( - _ request: SdReloadModelsRequest, - callOptions: CallOptions? = nil - ) -> UnaryCall { - return self.makeUnaryCall( - path: SdModelServiceClientMetadata.Methods.reloadModels.path, - request: request, - callOptions: callOptions ?? self.defaultCallOptions, - interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [] - ) - } - - /// Unary call to LoadModel + ///* + /// Loads a model onto a compute unit. /// /// - Parameters: /// - request: Request to send to LoadModel. @@ -167,6 +150,8 @@ public struct SdModelServiceNIOClient: SdModelServiceClientProtocol { } #if compiler(>=5.6) +///* +/// The model service, for management and loading of models. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol SdModelServiceAsyncClientProtocol: GRPCClient { static var serviceDescriptor: GRPCServiceDescriptor { get } @@ -177,11 +162,6 @@ public protocol SdModelServiceAsyncClientProtocol: GRPCClient { callOptions: CallOptions? ) -> GRPCAsyncUnaryCall - func makeReloadModelsCall( - _ request: SdReloadModelsRequest, - callOptions: CallOptions? - ) -> GRPCAsyncUnaryCall - func makeLoadModelCall( _ request: SdLoadModelRequest, callOptions: CallOptions? @@ -210,18 +190,6 @@ extension SdModelServiceAsyncClientProtocol { ) } - public func makeReloadModelsCall( - _ request: SdReloadModelsRequest, - callOptions: CallOptions? = nil - ) -> GRPCAsyncUnaryCall { - return self.makeAsyncUnaryCall( - path: SdModelServiceClientMetadata.Methods.reloadModels.path, - request: request, - callOptions: callOptions ?? self.defaultCallOptions, - interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [] - ) - } - public func makeLoadModelCall( _ request: SdLoadModelRequest, callOptions: CallOptions? = nil @@ -249,18 +217,6 @@ extension SdModelServiceAsyncClientProtocol { ) } - public func reloadModels( - _ request: SdReloadModelsRequest, - callOptions: CallOptions? = nil - ) async throws -> SdReloadModelsResponse { - return try await self.performAsyncUnaryCall( - path: SdModelServiceClientMetadata.Methods.reloadModels.path, - request: request, - callOptions: callOptions ?? self.defaultCallOptions, - interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [] - ) - } - public func loadModel( _ request: SdLoadModelRequest, callOptions: CallOptions? = nil @@ -298,9 +254,6 @@ public protocol SdModelServiceClientInterceptorFactoryProtocol: GRPCSendable { /// - Returns: Interceptors to use when invoking 'listModels'. func makeListModelsInterceptors() -> [ClientInterceptor] - /// - Returns: Interceptors to use when invoking 'reloadModels'. - func makeReloadModelsInterceptors() -> [ClientInterceptor] - /// - Returns: Interceptors to use when invoking 'loadModel'. func makeLoadModelInterceptors() -> [ClientInterceptor] } @@ -311,7 +264,6 @@ public enum SdModelServiceClientMetadata { fullName: "gay.pizza.stable.diffusion.ModelService", methods: [ SdModelServiceClientMetadata.Methods.listModels, - SdModelServiceClientMetadata.Methods.reloadModels, SdModelServiceClientMetadata.Methods.loadModel, ] ) @@ -323,12 +275,6 @@ public enum SdModelServiceClientMetadata { type: GRPCCallType.unary ) - public static let reloadModels = GRPCMethodDescriptor( - name: "ReloadModels", - path: "/gay.pizza.stable.diffusion.ModelService/ReloadModels", - type: GRPCCallType.unary - ) - public static let loadModel = GRPCMethodDescriptor( name: "LoadModel", path: "/gay.pizza.stable.diffusion.ModelService/LoadModel", @@ -337,12 +283,15 @@ public enum SdModelServiceClientMetadata { } } +///* +/// The image generation service, for generating images from loaded models. +/// /// Usage: instantiate `SdImageGenerationServiceClient`, then call methods of this protocol to make API calls. public protocol SdImageGenerationServiceClientProtocol: GRPCClient { var serviceName: String { get } var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get } - func generateImage( + func generateImages( _ request: SdGenerateImagesRequest, callOptions: CallOptions? ) -> UnaryCall @@ -353,21 +302,22 @@ extension SdImageGenerationServiceClientProtocol { return "gay.pizza.stable.diffusion.ImageGenerationService" } - /// Unary call to GenerateImage + ///* + /// Generates images using a loaded model. /// /// - Parameters: - /// - request: Request to send to GenerateImage. + /// - request: Request to send to GenerateImages. /// - callOptions: Call options. /// - Returns: A `UnaryCall` with futures for the metadata, status and response. - public func generateImage( + public func generateImages( _ request: SdGenerateImagesRequest, callOptions: CallOptions? = nil ) -> UnaryCall { return self.makeUnaryCall( - path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path, + path: SdImageGenerationServiceClientMetadata.Methods.generateImages.path, request: request, callOptions: callOptions ?? self.defaultCallOptions, - interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [] + interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [] ) } } @@ -432,12 +382,14 @@ public struct SdImageGenerationServiceNIOClient: SdImageGenerationServiceClientP } #if compiler(>=5.6) +///* +/// The image generation service, for generating images from loaded models. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient { static var serviceDescriptor: GRPCServiceDescriptor { get } var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get } - func makeGenerateImageCall( + func makeGenerateImagesCall( _ request: SdGenerateImagesRequest, callOptions: CallOptions? ) -> GRPCAsyncUnaryCall @@ -453,30 +405,30 @@ extension SdImageGenerationServiceAsyncClientProtocol { return nil } - public func makeGenerateImageCall( + public func makeGenerateImagesCall( _ request: SdGenerateImagesRequest, callOptions: CallOptions? = nil ) -> GRPCAsyncUnaryCall { return self.makeAsyncUnaryCall( - path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path, + path: SdImageGenerationServiceClientMetadata.Methods.generateImages.path, request: request, callOptions: callOptions ?? self.defaultCallOptions, - interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [] + interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [] ) } } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension SdImageGenerationServiceAsyncClientProtocol { - public func generateImage( + public func generateImages( _ request: SdGenerateImagesRequest, callOptions: CallOptions? = nil ) async throws -> SdGenerateImagesResponse { return try await self.performAsyncUnaryCall( - path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path, + path: SdImageGenerationServiceClientMetadata.Methods.generateImages.path, request: request, callOptions: callOptions ?? self.defaultCallOptions, - interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [] + interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [] ) } } @@ -502,8 +454,8 @@ public struct SdImageGenerationServiceAsyncClient: SdImageGenerationServiceAsync public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSendable { - /// - Returns: Interceptors to use when invoking 'generateImage'. - func makeGenerateImageInterceptors() -> [ClientInterceptor] + /// - Returns: Interceptors to use when invoking 'generateImages'. + func makeGenerateImagesInterceptors() -> [ClientInterceptor] } public enum SdImageGenerationServiceClientMetadata { @@ -511,27 +463,33 @@ public enum SdImageGenerationServiceClientMetadata { name: "ImageGenerationService", fullName: "gay.pizza.stable.diffusion.ImageGenerationService", methods: [ - SdImageGenerationServiceClientMetadata.Methods.generateImage, + SdImageGenerationServiceClientMetadata.Methods.generateImages, ] ) public enum Methods { - public static let generateImage = GRPCMethodDescriptor( - name: "GenerateImage", - path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImage", + public static let generateImages = GRPCMethodDescriptor( + name: "GenerateImages", + path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages", type: GRPCCallType.unary ) } } +///* +/// The model service, for management and loading of models. +/// /// To build a server, implement a class that conforms to this protocol. public protocol SdModelServiceProvider: CallHandlerProvider { var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get } + ///* + /// Lists the available models on the host. + /// This will return both models that are currently loaded, and models that are not yet loaded. func listModels(request: SdListModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture - func reloadModels(request: SdReloadModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture - + ///* + /// Loads a model onto a compute unit. func loadModel(request: SdLoadModelRequest, context: StatusOnlyCallContext) -> EventLoopFuture } @@ -556,15 +514,6 @@ extension SdModelServiceProvider { userFunction: self.listModels(request:context:) ) - case "ReloadModels": - return UnaryServerHandler( - context: context, - requestDeserializer: ProtobufDeserializer(), - responseSerializer: ProtobufSerializer(), - interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [], - userFunction: self.reloadModels(request:context:) - ) - case "LoadModel": return UnaryServerHandler( context: context, @@ -582,22 +531,25 @@ extension SdModelServiceProvider { #if compiler(>=5.6) +///* +/// The model service, for management and loading of models. +/// /// To implement a server, implement an object which conforms to this protocol. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol SdModelServiceAsyncProvider: CallHandlerProvider { static var serviceDescriptor: GRPCServiceDescriptor { get } var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get } + ///* + /// Lists the available models on the host. + /// This will return both models that are currently loaded, and models that are not yet loaded. @Sendable func listModels( request: SdListModelsRequest, context: GRPCAsyncServerCallContext ) async throws -> SdListModelsResponse - @Sendable func reloadModels( - request: SdReloadModelsRequest, - context: GRPCAsyncServerCallContext - ) async throws -> SdReloadModelsResponse - + ///* + /// Loads a model onto a compute unit. @Sendable func loadModel( request: SdLoadModelRequest, context: GRPCAsyncServerCallContext @@ -632,15 +584,6 @@ extension SdModelServiceAsyncProvider { wrapping: self.listModels(request:context:) ) - case "ReloadModels": - return GRPCAsyncServerHandler( - context: context, - requestDeserializer: ProtobufDeserializer(), - responseSerializer: ProtobufSerializer(), - interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [], - wrapping: self.reloadModels(request:context:) - ) - case "LoadModel": return GRPCAsyncServerHandler( context: context, @@ -664,10 +607,6 @@ public protocol SdModelServiceServerInterceptorFactoryProtocol { /// Defaults to calling `self.makeInterceptors()`. func makeListModelsInterceptors() -> [ServerInterceptor] - /// - Returns: Interceptors to use when handling 'reloadModels'. - /// Defaults to calling `self.makeInterceptors()`. - func makeReloadModelsInterceptors() -> [ServerInterceptor] - /// - Returns: Interceptors to use when handling 'loadModel'. /// Defaults to calling `self.makeInterceptors()`. func makeLoadModelInterceptors() -> [ServerInterceptor] @@ -679,7 +618,6 @@ public enum SdModelServiceServerMetadata { fullName: "gay.pizza.stable.diffusion.ModelService", methods: [ SdModelServiceServerMetadata.Methods.listModels, - SdModelServiceServerMetadata.Methods.reloadModels, SdModelServiceServerMetadata.Methods.loadModel, ] ) @@ -691,12 +629,6 @@ public enum SdModelServiceServerMetadata { type: GRPCCallType.unary ) - public static let reloadModels = GRPCMethodDescriptor( - name: "ReloadModels", - path: "/gay.pizza.stable.diffusion.ModelService/ReloadModels", - type: GRPCCallType.unary - ) - public static let loadModel = GRPCMethodDescriptor( name: "LoadModel", path: "/gay.pizza.stable.diffusion.ModelService/LoadModel", @@ -704,11 +636,16 @@ public enum SdModelServiceServerMetadata { ) } } +///* +/// The image generation service, for generating images from loaded models. +/// /// To build a server, implement a class that conforms to this protocol. public protocol SdImageGenerationServiceProvider: CallHandlerProvider { var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get } - func generateImage(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture + ///* + /// Generates images using a loaded model. + func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture } extension SdImageGenerationServiceProvider { @@ -723,13 +660,13 @@ extension SdImageGenerationServiceProvider { context: CallHandlerContext ) -> GRPCServerHandlerProtocol? { switch name { - case "GenerateImage": + case "GenerateImages": return UnaryServerHandler( context: context, requestDeserializer: ProtobufDeserializer(), responseSerializer: ProtobufSerializer(), - interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [], - userFunction: self.generateImage(request:context:) + interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [], + userFunction: self.generateImages(request:context:) ) default: @@ -740,13 +677,18 @@ extension SdImageGenerationServiceProvider { #if compiler(>=5.6) +///* +/// The image generation service, for generating images from loaded models. +/// /// To implement a server, implement an object which conforms to this protocol. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider { static var serviceDescriptor: GRPCServiceDescriptor { get } var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get } - @Sendable func generateImage( + ///* + /// Generates images using a loaded model. + @Sendable func generateImages( request: SdGenerateImagesRequest, context: GRPCAsyncServerCallContext ) async throws -> SdGenerateImagesResponse @@ -771,13 +713,13 @@ extension SdImageGenerationServiceAsyncProvider { context: CallHandlerContext ) -> GRPCServerHandlerProtocol? { switch name { - case "GenerateImage": + case "GenerateImages": return GRPCAsyncServerHandler( context: context, requestDeserializer: ProtobufDeserializer(), responseSerializer: ProtobufSerializer(), - interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [], - wrapping: self.generateImage(request:context:) + interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [], + wrapping: self.generateImages(request:context:) ) default: @@ -790,9 +732,9 @@ extension SdImageGenerationServiceAsyncProvider { public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol { - /// - Returns: Interceptors to use when handling 'generateImage'. + /// - Returns: Interceptors to use when handling 'generateImages'. /// Defaults to calling `self.makeInterceptors()`. - func makeGenerateImageInterceptors() -> [ServerInterceptor] + func makeGenerateImagesInterceptors() -> [ServerInterceptor] } public enum SdImageGenerationServiceServerMetadata { @@ -800,14 +742,14 @@ public enum SdImageGenerationServiceServerMetadata { name: "ImageGenerationService", fullName: "gay.pizza.stable.diffusion.ImageGenerationService", methods: [ - SdImageGenerationServiceServerMetadata.Methods.generateImage, + SdImageGenerationServiceServerMetadata.Methods.generateImages, ] ) public enum Methods { - public static let generateImage = GRPCMethodDescriptor( - name: "GenerateImage", - path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImage", + public static let generateImages = GRPCMethodDescriptor( + name: "GenerateImages", + path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages", type: GRPCCallType.unary ) } diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift index 4c4360c..25ea181 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -7,6 +7,9 @@ // For information on using the generated types, please see the documentation: // https://github.com/apple/swift-protobuf/ +///* +/// Stable Diffusion RPC service for Apple Platforms. + import Foundation import SwiftProtobuf @@ -20,9 +23,67 @@ fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAP typealias Version = _2 } +///* +/// Represents the model attention. Model attention has to do with how the model is encoded, and +/// can determine what compute units are able to support a particular model. +public enum SdModelAttention: SwiftProtobuf.Enum { + public typealias RawValue = Int + + ///* + /// The model is an original attention type. It can be loaded only onto CPU & GPU compute units. + case original // = 0 + + ///* + /// The model is a split-ein-sum attention type. It can be loaded onto all compute units, + /// including the Apple Neural Engine. + case splitEinSum // = 1 + case UNRECOGNIZED(Int) + + public init() { + self = .original + } + + public init?(rawValue: Int) { + switch rawValue { + case 0: self = .original + case 1: self = .splitEinSum + default: self = .UNRECOGNIZED(rawValue) + } + } + + public var rawValue: Int { + switch self { + case .original: return 0 + case .splitEinSum: return 1 + case .UNRECOGNIZED(let i): return i + } + } + +} + +#if swift(>=4.2) + +extension SdModelAttention: CaseIterable { + // The compiler won't synthesize support with the UNRECOGNIZED case. + public static var allCases: [SdModelAttention] = [ + .original, + .splitEinSum, + ] +} + +#endif // swift(>=4.2) + +///* +/// Represents the schedulers that are used to sample images. public enum SdScheduler: SwiftProtobuf.Enum { public typealias RawValue = Int + + ///* + /// The PNDM (Pseudo numerical methods for diffusion models) scheduler. case pndm // = 0 + + ///* + /// The DPM-Solver++ scheduler. case dpmSolverPlusPlus // = 1 case UNRECOGNIZED(Int) @@ -60,11 +121,25 @@ extension SdScheduler: CaseIterable { #endif // swift(>=4.2) +///* +/// Represents a specifier for what compute units are available for ML tasks. public enum SdComputeUnits: SwiftProtobuf.Enum { public typealias RawValue = Int + + ///* + /// The CPU as a singular compute unit. case cpu // = 0 + + ///* + /// The CPU & GPU combined into a singular compute unit. case cpuAndGpu // = 1 + + ///* + /// Allow the usage of all compute units. CoreML will decided where the model is loaded. case all // = 2 + + ///* + /// The CPU & Neural Engine combined into a singular compute unit. case cpuAndNeuralEngine // = 3 case UNRECOGNIZED(Int) @@ -108,34 +183,108 @@ extension SdComputeUnits: CaseIterable { #endif // swift(>=4.2) +///* +/// Represents the format of an image. +public enum SdImageFormat: SwiftProtobuf.Enum { + public typealias RawValue = Int + + ///* + /// The PNG image format. + case png // = 0 + case UNRECOGNIZED(Int) + + public init() { + self = .png + } + + public init?(rawValue: Int) { + switch rawValue { + case 0: self = .png + default: self = .UNRECOGNIZED(rawValue) + } + } + + public var rawValue: Int { + switch self { + case .png: return 0 + case .UNRECOGNIZED(let i): return i + } + } + +} + +#if swift(>=4.2) + +extension SdImageFormat: CaseIterable { + // The compiler won't synthesize support with the UNRECOGNIZED case. + public static var allCases: [SdImageFormat] = [ + .png, + ] +} + +#endif // swift(>=4.2) + +///* +/// Represents information about an available model. +/// The primary key of a model is it's 'name' field. public struct SdModelInfo { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The name of the available model. Note that within the context of a single RPC server, + /// the name of a model is a unique identifier. This may not be true when utilizing a cluster or + /// load balanced server, so keep that in mind. public var name: String = String() - public var attention: String = String() + ///* + /// The attention of the model. Model attention determines what compute units can be used to + /// load the model and make predictions. + public var attention: SdModelAttention = .original + ///* + /// Whether the model is currently loaded onto an available compute unit. public var isLoaded: Bool = false + ///* + /// The compute unit that the model is currently loaded into, if it is loaded to one at all. + /// When is_loaded is false, the value of this field should be null. + public var loadedComputeUnits: SdComputeUnits = .cpu + + ///* + /// The compute units that this model supports using. + public var supportedComputeUnits: [SdComputeUnits] = [] + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } +///* +/// Represents an image within the Stable Diffusion context. +/// This could be an input image for an image generation request, or it could be +/// a generated image from the Stable Diffusion model. public struct SdImage { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. - public var content: Data = Data() + ///* + /// The format of the image. + public var format: SdImageFormat = .png + + ///* + /// The raw data of the image, in the specified format. + public var data: Data = Data() public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } +///* +/// Represents a request to list the models available on the host. public struct SdListModelsRequest { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for @@ -146,54 +295,44 @@ public struct SdListModelsRequest { public init() {} } +///* +/// Represents a response to listing the models available on the host. public struct SdListModelsResponse { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. - public var models: [SdModelInfo] = [] - - public var unknownFields = SwiftProtobuf.UnknownStorage() - - public init() {} -} - -public struct SdReloadModelsRequest { - // SwiftProtobuf.Message conformance is added in an extension below. See the - // `Message` and `Message+*Additions` files in the SwiftProtobuf library for - // methods supported on all messages. - - public var unknownFields = SwiftProtobuf.UnknownStorage() - - public init() {} -} - -public struct SdReloadModelsResponse { - // SwiftProtobuf.Message conformance is added in an extension below. See the - // `Message` and `Message+*Additions` files in the SwiftProtobuf library for - // methods supported on all messages. + ///* + /// The available models on the Stable Diffusion server. + public var availableModels: [SdModelInfo] = [] public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } +///* +/// Represents a request to load a model into a specified compute unit. public struct SdLoadModelRequest { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The model name to load onto the compute unit. public var modelName: String = String() + ///* + /// The compute units to load the model onto. public var computeUnits: SdComputeUnits = .cpu - public var reduceMemory: Bool = false - public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } +///* +/// Represents a response to loading a model. public struct SdLoadModelResponse { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for @@ -204,45 +343,77 @@ public struct SdLoadModelResponse { public init() {} } +///* +/// Represents a request to generate images using a loaded model. public struct SdGenerateImagesRequest { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The model name to use for generation. + /// The model must be already be loaded using ModelService.LoadModel RPC method. public var modelName: String = String() - public var imageCount: UInt32 = 0 + ///* + /// The output format for generated images. + public var outputImageFormat: SdImageFormat = .png + ///* + /// The number of batches of images to generate. + public var batchCount: UInt32 = 0 + + ///* + /// The number of images inside a single batch. + public var batchSize: UInt32 = 0 + + ///* + /// The positive textual prompt for image generation. public var prompt: String = String() + ///* + /// The negative prompt for image generation. public var negativePrompt: String = String() + ///* + /// The random seed to use. + /// Zero indicates that the seed should be random. + public var seed: UInt32 = 0 + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } +///* +/// Represents the response from image generation. public struct SdGenerateImagesResponse { // SwiftProtobuf.Message conformance is added in an extension below. See the // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + ///* + /// The set of generated images by the Stable Diffusion pipeline. public var images: [SdImage] = [] + ///* + /// The seeds that were used to generate the images. + public var seeds: [UInt32] = [] + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} } #if swift(>=5.5) && canImport(_Concurrency) +extension SdModelAttention: @unchecked Sendable {} extension SdScheduler: @unchecked Sendable {} extension SdComputeUnits: @unchecked Sendable {} +extension SdImageFormat: @unchecked Sendable {} extension SdModelInfo: @unchecked Sendable {} extension SdImage: @unchecked Sendable {} extension SdListModelsRequest: @unchecked Sendable {} extension SdListModelsResponse: @unchecked Sendable {} -extension SdReloadModelsRequest: @unchecked Sendable {} -extension SdReloadModelsResponse: @unchecked Sendable {} extension SdLoadModelRequest: @unchecked Sendable {} extension SdLoadModelResponse: @unchecked Sendable {} extension SdGenerateImagesRequest: @unchecked Sendable {} @@ -253,10 +424,17 @@ extension SdGenerateImagesResponse: @unchecked Sendable {} fileprivate let _protobuf_package = "gay.pizza.stable.diffusion" +extension SdModelAttention: SwiftProtobuf._ProtoNameProviding { + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 0: .same(proto: "original"), + 1: .same(proto: "split_ein_sum"), + ] +} + extension SdScheduler: SwiftProtobuf._ProtoNameProviding { public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 0: .same(proto: "pndm"), - 1: .same(proto: "dpmSolverPlusPlus"), + 1: .same(proto: "dpm_solver_plus_plus"), ] } @@ -269,12 +447,20 @@ extension SdComputeUnits: SwiftProtobuf._ProtoNameProviding { ] } +extension SdImageFormat: SwiftProtobuf._ProtoNameProviding { + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 0: .same(proto: "png"), + ] +} + extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { public static let protoMessageName: String = _protobuf_package + ".ModelInfo" public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .same(proto: "name"), 2: .same(proto: "attention"), 3: .standard(proto: "is_loaded"), + 4: .standard(proto: "loaded_compute_units"), + 5: .standard(proto: "supported_compute_units"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -284,8 +470,10 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { case 1: try { try decoder.decodeSingularStringField(value: &self.name) }() - case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }() + case 2: try { try decoder.decodeSingularEnumField(value: &self.attention) }() case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }() + case 4: try { try decoder.decodeSingularEnumField(value: &self.loadedComputeUnits) }() + case 5: try { try decoder.decodeRepeatedEnumField(value: &self.supportedComputeUnits) }() default: break } } @@ -295,12 +483,18 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati if !self.name.isEmpty { try visitor.visitSingularStringField(value: self.name, fieldNumber: 1) } - if !self.attention.isEmpty { - try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2) + if self.attention != .original { + try visitor.visitSingularEnumField(value: self.attention, fieldNumber: 2) } if self.isLoaded != false { try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3) } + if self.loadedComputeUnits != .cpu { + try visitor.visitSingularEnumField(value: self.loadedComputeUnits, fieldNumber: 4) + } + if !self.supportedComputeUnits.isEmpty { + try visitor.visitPackedEnumField(value: self.supportedComputeUnits, fieldNumber: 5) + } try unknownFields.traverse(visitor: &visitor) } @@ -308,6 +502,8 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati if lhs.name != rhs.name {return false} if lhs.attention != rhs.attention {return false} if lhs.isLoaded != rhs.isLoaded {return false} + if lhs.loadedComputeUnits != rhs.loadedComputeUnits {return false} + if lhs.supportedComputeUnits != rhs.supportedComputeUnits {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -316,7 +512,8 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati extension SdImage: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { public static let protoMessageName: String = _protobuf_package + ".Image" public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ - 1: .same(proto: "content"), + 1: .same(proto: "format"), + 2: .same(proto: "data"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -325,21 +522,26 @@ extension SdImage: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBa // allocates stack space for every case branch when no optimizations are // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { - case 1: try { try decoder.decodeSingularBytesField(value: &self.content) }() + case 1: try { try decoder.decodeSingularEnumField(value: &self.format) }() + case 2: try { try decoder.decodeSingularBytesField(value: &self.data) }() default: break } } } public func traverse(visitor: inout V) throws { - if !self.content.isEmpty { - try visitor.visitSingularBytesField(value: self.content, fieldNumber: 1) + if self.format != .png { + try visitor.visitSingularEnumField(value: self.format, fieldNumber: 1) + } + if !self.data.isEmpty { + try visitor.visitSingularBytesField(value: self.data, fieldNumber: 2) } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdImage, rhs: SdImage) -> Bool { - if lhs.content != rhs.content {return false} + if lhs.format != rhs.format {return false} + if lhs.data != rhs.data {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -367,7 +569,7 @@ extension SdListModelsRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImpl extension SdListModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { public static let protoMessageName: String = _protobuf_package + ".ListModelsResponse" public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ - 1: .same(proto: "models"), + 1: .standard(proto: "available_models"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -376,59 +578,21 @@ extension SdListModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImp // allocates stack space for every case branch when no optimizations are // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { - case 1: try { try decoder.decodeRepeatedMessageField(value: &self.models) }() + case 1: try { try decoder.decodeRepeatedMessageField(value: &self.availableModels) }() default: break } } } public func traverse(visitor: inout V) throws { - if !self.models.isEmpty { - try visitor.visitRepeatedMessageField(value: self.models, fieldNumber: 1) + if !self.availableModels.isEmpty { + try visitor.visitRepeatedMessageField(value: self.availableModels, fieldNumber: 1) } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdListModelsResponse, rhs: SdListModelsResponse) -> Bool { - if lhs.models != rhs.models {return false} - if lhs.unknownFields != rhs.unknownFields {return false} - return true - } -} - -extension SdReloadModelsRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { - public static let protoMessageName: String = _protobuf_package + ".ReloadModelsRequest" - public static let _protobuf_nameMap = SwiftProtobuf._NameMap() - - public mutating func decodeMessage(decoder: inout D) throws { - while let _ = try decoder.nextFieldNumber() { - } - } - - public func traverse(visitor: inout V) throws { - try unknownFields.traverse(visitor: &visitor) - } - - public static func ==(lhs: SdReloadModelsRequest, rhs: SdReloadModelsRequest) -> Bool { - if lhs.unknownFields != rhs.unknownFields {return false} - return true - } -} - -extension SdReloadModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { - public static let protoMessageName: String = _protobuf_package + ".ReloadModelsResponse" - public static let _protobuf_nameMap = SwiftProtobuf._NameMap() - - public mutating func decodeMessage(decoder: inout D) throws { - while let _ = try decoder.nextFieldNumber() { - } - } - - public func traverse(visitor: inout V) throws { - try unknownFields.traverse(visitor: &visitor) - } - - public static func ==(lhs: SdReloadModelsResponse, rhs: SdReloadModelsResponse) -> Bool { + if lhs.availableModels != rhs.availableModels {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -439,7 +603,6 @@ extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImple public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .standard(proto: "model_name"), 2: .standard(proto: "compute_units"), - 3: .standard(proto: "reduce_memory"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -450,7 +613,6 @@ extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImple switch fieldNumber { case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }() case 2: try { try decoder.decodeSingularEnumField(value: &self.computeUnits) }() - case 3: try { try decoder.decodeSingularBoolField(value: &self.reduceMemory) }() default: break } } @@ -463,16 +625,12 @@ extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImple if self.computeUnits != .cpu { try visitor.visitSingularEnumField(value: self.computeUnits, fieldNumber: 2) } - if self.reduceMemory != false { - try visitor.visitSingularBoolField(value: self.reduceMemory, fieldNumber: 3) - } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdLoadModelRequest, rhs: SdLoadModelRequest) -> Bool { if lhs.modelName != rhs.modelName {return false} if lhs.computeUnits != rhs.computeUnits {return false} - if lhs.reduceMemory != rhs.reduceMemory {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -501,9 +659,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message public static let protoMessageName: String = _protobuf_package + ".GenerateImagesRequest" public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .standard(proto: "model_name"), - 2: .standard(proto: "image_count"), - 3: .same(proto: "prompt"), - 4: .standard(proto: "negative_prompt"), + 2: .standard(proto: "output_image_format"), + 3: .standard(proto: "batch_count"), + 4: .standard(proto: "batch_size"), + 5: .same(proto: "prompt"), + 6: .standard(proto: "negative_prompt"), + 7: .same(proto: "seed"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -513,9 +674,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }() - case 2: try { try decoder.decodeSingularUInt32Field(value: &self.imageCount) }() - case 3: try { try decoder.decodeSingularStringField(value: &self.prompt) }() - case 4: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }() + case 2: try { try decoder.decodeSingularEnumField(value: &self.outputImageFormat) }() + case 3: try { try decoder.decodeSingularUInt32Field(value: &self.batchCount) }() + case 4: try { try decoder.decodeSingularUInt32Field(value: &self.batchSize) }() + case 5: try { try decoder.decodeSingularStringField(value: &self.prompt) }() + case 6: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }() + case 7: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }() default: break } } @@ -525,23 +689,35 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message if !self.modelName.isEmpty { try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1) } - if self.imageCount != 0 { - try visitor.visitSingularUInt32Field(value: self.imageCount, fieldNumber: 2) + if self.outputImageFormat != .png { + try visitor.visitSingularEnumField(value: self.outputImageFormat, fieldNumber: 2) + } + if self.batchCount != 0 { + try visitor.visitSingularUInt32Field(value: self.batchCount, fieldNumber: 3) + } + if self.batchSize != 0 { + try visitor.visitSingularUInt32Field(value: self.batchSize, fieldNumber: 4) } if !self.prompt.isEmpty { - try visitor.visitSingularStringField(value: self.prompt, fieldNumber: 3) + try visitor.visitSingularStringField(value: self.prompt, fieldNumber: 5) } if !self.negativePrompt.isEmpty { - try visitor.visitSingularStringField(value: self.negativePrompt, fieldNumber: 4) + try visitor.visitSingularStringField(value: self.negativePrompt, fieldNumber: 6) + } + if self.seed != 0 { + try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 7) } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdGenerateImagesRequest, rhs: SdGenerateImagesRequest) -> Bool { if lhs.modelName != rhs.modelName {return false} - if lhs.imageCount != rhs.imageCount {return false} + if lhs.outputImageFormat != rhs.outputImageFormat {return false} + if lhs.batchCount != rhs.batchCount {return false} + if lhs.batchSize != rhs.batchSize {return false} if lhs.prompt != rhs.prompt {return false} if lhs.negativePrompt != rhs.negativePrompt {return false} + if lhs.seed != rhs.seed {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -551,6 +727,7 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag public static let protoMessageName: String = _protobuf_package + ".GenerateImagesResponse" public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .same(proto: "images"), + 2: .same(proto: "seeds"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -560,6 +737,7 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }() + case 2: try { try decoder.decodeRepeatedUInt32Field(value: &self.seeds) }() default: break } } @@ -569,11 +747,15 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag if !self.images.isEmpty { try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1) } + if !self.seeds.isEmpty { + try visitor.visitPackedUInt32Field(value: self.seeds, fieldNumber: 2) + } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdGenerateImagesResponse, rhs: SdGenerateImagesResponse) -> Bool { if lhs.images != rhs.images {return false} + if lhs.seeds != rhs.seeds {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } diff --git a/Sources/StableDiffusionProtos/Utilities.swift b/Sources/StableDiffusionProtos/Utilities.swift new file mode 100644 index 0000000..57de721 --- /dev/null +++ b/Sources/StableDiffusionProtos/Utilities.swift @@ -0,0 +1,26 @@ +import CoreML +import Foundation + +public extension SdComputeUnits { + func toMlComputeUnits() -> MLComputeUnits { + switch self { + case .all: return .all + case .cpu: return .cpuOnly + case .cpuAndGpu: return .cpuAndGPU + case .cpuAndNeuralEngine: return .cpuAndNeuralEngine + default: return .all + } + } +} + +public extension MLComputeUnits { + func toSdComputeUnits() -> SdComputeUnits { + switch self { + case .all: return .all + case .cpuOnly: return .cpu + case .cpuAndGPU: return .cpuAndGpu + case .cpuAndNeuralEngine: return .cpuAndNeuralEngine + default: return .all + } + } +} diff --git a/Sources/StableDiffusionServer/ImageGenerationService.swift b/Sources/StableDiffusionServer/ImageGenerationService.swift index 0b33c2c..993922a 100644 --- a/Sources/StableDiffusionServer/ImageGenerationService.swift +++ b/Sources/StableDiffusionServer/ImageGenerationService.swift @@ -10,7 +10,7 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider { self.modelManager = modelManager } - func generateImage(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse { + func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse { guard let state = await modelManager.getModelState(name: request.modelName) else { throw SdCoreError.modelNotFound } diff --git a/Sources/StableDiffusionServer/ModelService.swift b/Sources/StableDiffusionServer/ModelService.swift index 17d2a9e..1a36dbf 100644 --- a/Sources/StableDiffusionServer/ModelService.swift +++ b/Sources/StableDiffusionServer/ModelService.swift @@ -11,20 +11,15 @@ class ModelServiceProvider: SdModelServiceAsyncProvider { } func listModels(request _: SdListModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse { - let models = await modelManager.listModels() + let models = try await modelManager.listAvailableModels() var response = SdListModelsResponse() - response.models.append(contentsOf: models) + response.availableModels.append(contentsOf: models) return response } - func reloadModels(request _: SdReloadModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse { - try await modelManager.reloadModels() - return SdReloadModelsResponse() - } - func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse { let state = try await modelManager.createModelState(name: request.modelName) - try await state.load() + try await state.load(request: request) return SdLoadModelResponse() } } diff --git a/Sources/StableDiffusionServer/main.swift b/Sources/StableDiffusionServer/main.swift index d0384ba..f82803c 100644 --- a/Sources/StableDiffusionServer/main.swift +++ b/Sources/StableDiffusionServer/main.swift @@ -16,7 +16,7 @@ struct ServerCommand: ParsableCommand { let semaphore = DispatchSemaphore(value: 0) Task { do { - try await modelManager.reloadModels() + try await modelManager.reloadAvailableModels() } catch { ServerCommand.exit(withError: error) }