mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-04 05:51:32 +00:00
Document API, make the implementation match the API, and update the same.
This commit is contained in:
@ -26,15 +26,6 @@ java {
|
|||||||
withSourcesJar()
|
withSourcesJar()
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceSets {
|
|
||||||
main {
|
|
||||||
proto {
|
|
||||||
srcDir("../../Common")
|
|
||||||
include("*.proto")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("org.jetbrains.kotlin:kotlin-bom")
|
implementation("org.jetbrains.kotlin:kotlin-bom")
|
||||||
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
package gay.pizza.stable.diffusion.sample
|
package gay.pizza.stable.diffusion.sample
|
||||||
|
|
||||||
|
import gay.pizza.stable.diffusion.StableDiffusion
|
||||||
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
|
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
|
||||||
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
|
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
|
||||||
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
|
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
|
||||||
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
||||||
import io.grpc.ManagedChannelBuilder
|
import io.grpc.ManagedChannelBuilder
|
||||||
|
import kotlin.io.path.Path
|
||||||
|
import kotlin.io.path.writeBytes
|
||||||
import kotlin.system.exitProcess
|
import kotlin.system.exitProcess
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
@ -15,20 +18,22 @@ fun main() {
|
|||||||
|
|
||||||
val client = StableDiffusionRpcClient(channel)
|
val client = StableDiffusionRpcClient(channel)
|
||||||
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
|
val modelListResponse = client.modelServiceBlocking.listModels(ListModelsRequest.getDefaultInstance())
|
||||||
if (modelListResponse.modelsList.isEmpty()) {
|
if (modelListResponse.availableModelsList.isEmpty()) {
|
||||||
println("no available models")
|
println("no available models")
|
||||||
exitProcess(0)
|
exitProcess(0)
|
||||||
}
|
}
|
||||||
println("available models:")
|
println("available models:")
|
||||||
for (model in modelListResponse.modelsList) {
|
for (model in modelListResponse.availableModelsList) {
|
||||||
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}")
|
val maybeLoadedComputeUnits = if (model.isLoaded) " loaded_compute_units=${model.loadedComputeUnits.name}" else ""
|
||||||
|
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
|
||||||
}
|
}
|
||||||
|
|
||||||
val model = modelListResponse.modelsList.random()
|
val model = modelListResponse.availableModelsList.random()
|
||||||
if (!model.isLoaded) {
|
if (!model.isLoaded) {
|
||||||
println("loading model ${model.name}...")
|
println("loading model ${model.name}...")
|
||||||
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
|
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
|
||||||
modelName = model.name
|
modelName = model.name
|
||||||
|
computeUnits = model.supportedComputeUnitsList.first()
|
||||||
}.build())
|
}.build())
|
||||||
} else {
|
} else {
|
||||||
println("using model ${model.name}...")
|
println("using model ${model.name}...")
|
||||||
@ -36,14 +41,22 @@ fun main() {
|
|||||||
|
|
||||||
println("generating images...")
|
println("generating images...")
|
||||||
|
|
||||||
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImage(GenerateImagesRequest.newBuilder().apply {
|
val request = GenerateImagesRequest.newBuilder().apply {
|
||||||
modelName = model.name
|
modelName = model.name
|
||||||
imageCount = 1
|
outputImageFormat = StableDiffusion.ImageFormat.png
|
||||||
|
batchSize = 2
|
||||||
|
batchCount = 2
|
||||||
prompt = "cat"
|
prompt = "cat"
|
||||||
negativePrompt = "bad, low quality, nsfw"
|
negativePrompt = "bad, low quality, nsfw"
|
||||||
}.build())
|
}.build()
|
||||||
|
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
|
||||||
|
|
||||||
println("generated ${generateImagesResponse.imagesCount} images")
|
println("generated ${generateImagesResponse.imagesCount} images:")
|
||||||
|
for ((index, image) in generateImagesResponse.imagesList.withIndex()) {
|
||||||
|
println(" image ${index + 1} format=${image.format.name} data=(${image.data.size()} bytes)")
|
||||||
|
val path = Path("work/image${index}.${image.format.name}")
|
||||||
|
path.writeBytes(image.data.toByteArray())
|
||||||
|
}
|
||||||
|
|
||||||
channel.shutdownNow()
|
channel.shutdownNow()
|
||||||
}
|
}
|
||||||
|
1
Clients/Java/src/main/proto
Symbolic link
1
Clients/Java/src/main/proto
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../../Common
|
@ -1,63 +1,247 @@
|
|||||||
|
/**
|
||||||
|
* Stable Diffusion RPC service for Apple Platforms.
|
||||||
|
*/
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
package gay.pizza.stable.diffusion;
|
package gay.pizza.stable.diffusion;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utilize a prefix of 'Sd' for Swift.
|
||||||
|
*/
|
||||||
option swift_prefix = "Sd";
|
option swift_prefix = "Sd";
|
||||||
|
|
||||||
message ModelInfo {
|
/**
|
||||||
string name = 1;
|
* Represents the model attention. Model attention has to do with how the model is encoded, and
|
||||||
string attention = 2;
|
* can determine what compute units are able to support a particular model.
|
||||||
bool is_loaded = 3;
|
*/
|
||||||
|
enum ModelAttention {
|
||||||
|
/**
|
||||||
|
* The model is an original attention type. It can be loaded only onto CPU & GPU compute units.
|
||||||
|
*/
|
||||||
|
original = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model is a split-ein-sum attention type. It can be loaded onto all compute units,
|
||||||
|
* including the Apple Neural Engine.
|
||||||
|
*/
|
||||||
|
split_ein_sum = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Image {
|
/**
|
||||||
bytes content = 1;
|
* Represents the schedulers that are used to sample images.
|
||||||
}
|
*/
|
||||||
|
|
||||||
message ListModelsRequest {}
|
|
||||||
message ListModelsResponse {
|
|
||||||
repeated ModelInfo models = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message ReloadModelsRequest {}
|
|
||||||
message ReloadModelsResponse {}
|
|
||||||
|
|
||||||
enum Scheduler {
|
enum Scheduler {
|
||||||
|
/**
|
||||||
|
* The PNDM (Pseudo numerical methods for diffusion models) scheduler.
|
||||||
|
*/
|
||||||
pndm = 0;
|
pndm = 0;
|
||||||
dpmSolverPlusPlus = 1;
|
|
||||||
|
/**
|
||||||
|
* The DPM-Solver++ scheduler.
|
||||||
|
*/
|
||||||
|
dpm_solver_plus_plus = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a specifier for what compute units are available for ML tasks.
|
||||||
|
*/
|
||||||
enum ComputeUnits {
|
enum ComputeUnits {
|
||||||
|
/**
|
||||||
|
* The CPU as a singular compute unit.
|
||||||
|
*/
|
||||||
cpu = 0;
|
cpu = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The CPU & GPU combined into a singular compute unit.
|
||||||
|
*/
|
||||||
cpu_and_gpu = 1;
|
cpu_and_gpu = 1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allow the usage of all compute units. CoreML will decided where the model is loaded.
|
||||||
|
*/
|
||||||
all = 2;
|
all = 2;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The CPU & Neural Engine combined into a singular compute unit.
|
||||||
|
*/
|
||||||
cpu_and_neural_engine = 3;
|
cpu_and_neural_engine = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoadModelRequest {
|
/**
|
||||||
string model_name = 1;
|
* Represents information about an available model.
|
||||||
ComputeUnits compute_units = 2;
|
* The primary key of a model is it's 'name' field.
|
||||||
bool reduce_memory = 3;
|
*/
|
||||||
|
message ModelInfo {
|
||||||
|
/**
|
||||||
|
* The name of the available model. Note that within the context of a single RPC server,
|
||||||
|
* the name of a model is a unique identifier. This may not be true when utilizing a cluster or
|
||||||
|
* load balanced server, so keep that in mind.
|
||||||
|
*/
|
||||||
|
string name = 1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The attention of the model. Model attention determines what compute units can be used to
|
||||||
|
* load the model and make predictions.
|
||||||
|
*/
|
||||||
|
ModelAttention attention = 2;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether the model is currently loaded onto an available compute unit.
|
||||||
|
*/
|
||||||
|
bool is_loaded = 3;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The compute unit that the model is currently loaded into, if it is loaded to one at all.
|
||||||
|
* When is_loaded is false, the value of this field should be null.
|
||||||
|
*/
|
||||||
|
ComputeUnits loaded_compute_units = 4;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The compute units that this model supports using.
|
||||||
|
*/
|
||||||
|
repeated ComputeUnits supported_compute_units = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the format of an image.
|
||||||
|
*/
|
||||||
|
enum ImageFormat {
|
||||||
|
/**
|
||||||
|
* The PNG image format.
|
||||||
|
*/
|
||||||
|
png = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents an image within the Stable Diffusion context.
|
||||||
|
* This could be an input image for an image generation request, or it could be
|
||||||
|
* a generated image from the Stable Diffusion model.
|
||||||
|
*/
|
||||||
|
message Image {
|
||||||
|
/**
|
||||||
|
* The format of the image.
|
||||||
|
*/
|
||||||
|
ImageFormat format = 1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The raw data of the image, in the specified format.
|
||||||
|
*/
|
||||||
|
bytes data = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a request to list the models available on the host.
|
||||||
|
*/
|
||||||
|
message ListModelsRequest {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a response to listing the models available on the host.
|
||||||
|
*/
|
||||||
|
message ListModelsResponse {
|
||||||
|
/**
|
||||||
|
* The available models on the Stable Diffusion server.
|
||||||
|
*/
|
||||||
|
repeated ModelInfo available_models = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a request to load a model into a specified compute unit.
|
||||||
|
*/
|
||||||
|
message LoadModelRequest {
|
||||||
|
/**
|
||||||
|
* The model name to load onto the compute unit.
|
||||||
|
*/
|
||||||
|
string model_name = 1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The compute units to load the model onto.
|
||||||
|
*/
|
||||||
|
ComputeUnits compute_units = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a response to loading a model.
|
||||||
|
*/
|
||||||
message LoadModelResponse {}
|
message LoadModelResponse {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model service, for management and loading of models.
|
||||||
|
*/
|
||||||
service ModelService {
|
service ModelService {
|
||||||
|
/**
|
||||||
|
* Lists the available models on the host.
|
||||||
|
* This will return both models that are currently loaded, and models that are not yet loaded.
|
||||||
|
*/
|
||||||
rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
|
rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
|
||||||
rpc ReloadModels(ReloadModelsRequest) returns (ReloadModelsResponse);
|
|
||||||
|
/**
|
||||||
|
* Loads a model onto a compute unit.
|
||||||
|
*/
|
||||||
rpc LoadModel(LoadModelRequest) returns (LoadModelResponse);
|
rpc LoadModel(LoadModelRequest) returns (LoadModelResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a request to generate images using a loaded model.
|
||||||
|
*/
|
||||||
message GenerateImagesRequest {
|
message GenerateImagesRequest {
|
||||||
|
/**
|
||||||
|
* The model name to use for generation.
|
||||||
|
* The model must be already be loaded using ModelService.LoadModel RPC method.
|
||||||
|
*/
|
||||||
string model_name = 1;
|
string model_name = 1;
|
||||||
uint32 image_count = 2;
|
|
||||||
string prompt = 3;
|
/**
|
||||||
string negative_prompt = 4;
|
* The output format for generated images.
|
||||||
|
*/
|
||||||
|
ImageFormat output_image_format = 2;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of batches of images to generate.
|
||||||
|
*/
|
||||||
|
uint32 batch_count = 3;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of images inside a single batch.
|
||||||
|
*/
|
||||||
|
uint32 batch_size = 4;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The positive textual prompt for image generation.
|
||||||
|
*/
|
||||||
|
string prompt = 5;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The negative prompt for image generation.
|
||||||
|
*/
|
||||||
|
string negative_prompt = 6;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The random seed to use.
|
||||||
|
* Zero indicates that the seed should be random.
|
||||||
|
*/
|
||||||
|
uint32 seed = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the response from image generation.
|
||||||
|
*/
|
||||||
message GenerateImagesResponse {
|
message GenerateImagesResponse {
|
||||||
|
/**
|
||||||
|
* The set of generated images by the Stable Diffusion pipeline.
|
||||||
|
*/
|
||||||
repeated Image images = 1;
|
repeated Image images = 1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The seeds that were used to generate the images.
|
||||||
|
*/
|
||||||
|
repeated uint32 seeds = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The image generation service, for generating images from loaded models.
|
||||||
|
*/
|
||||||
service ImageGenerationService {
|
service ImageGenerationService {
|
||||||
rpc GenerateImage(GenerateImagesRequest) returns (GenerateImagesResponse);
|
/**
|
||||||
|
* Generates images using a loaded model.
|
||||||
|
*/
|
||||||
|
rpc GenerateImages(GenerateImagesRequest) returns (GenerateImagesResponse);
|
||||||
}
|
}
|
||||||
|
@ -9,23 +9,25 @@ let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port
|
|||||||
Task { @MainActor in
|
Task { @MainActor in
|
||||||
do {
|
do {
|
||||||
let modelListResponse = try await client.modelService.listModels(.init())
|
let modelListResponse = try await client.modelService.listModels(.init())
|
||||||
print("Loading model...")
|
print("Loading random model...")
|
||||||
let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }!
|
let modelInfo = modelListResponse.availableModels.randomElement()!
|
||||||
_ = try await client.modelService.loadModel(.with { request in
|
_ = try await client.modelService.loadModel(.with { request in
|
||||||
request.modelName = modelInfo.name
|
request.modelName = modelInfo.name
|
||||||
})
|
})
|
||||||
print("Loaded model.")
|
print("Loaded random model.")
|
||||||
|
|
||||||
print("Generating image...")
|
print("Generating image...")
|
||||||
let request = SdGenerateImagesRequest.with {
|
let request = SdGenerateImagesRequest.with {
|
||||||
$0.modelName = modelInfo.name
|
$0.modelName = modelInfo.name
|
||||||
|
$0.outputImageFormat = .png
|
||||||
$0.prompt = "cat"
|
$0.prompt = "cat"
|
||||||
$0.imageCount = 1
|
$0.batchCount = 1
|
||||||
|
$0.batchSize = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = try await client.imageGenerationService.generateImage(request)
|
let response = try await client.imageGenerationService.generateImages(request)
|
||||||
let image = response.images.first!
|
let image = response.images.first!
|
||||||
try image.content.write(to: URL(filePath: "output.png"))
|
try image.data.write(to: URL(filePath: "output.png"))
|
||||||
print("Generated image to output.png")
|
print("Generated image to output.png")
|
||||||
exit(0)
|
exit(0)
|
||||||
} catch {
|
} catch {
|
||||||
|
@ -2,6 +2,6 @@ import Foundation
|
|||||||
|
|
||||||
public enum SdCoreError: Error {
|
public enum SdCoreError: Error {
|
||||||
case modelNotLoaded
|
case modelNotLoaded
|
||||||
case imageEncode
|
case imageEncodeFailed
|
||||||
case modelNotFound
|
case modelNotFound
|
||||||
}
|
}
|
||||||
|
@ -1,22 +1,38 @@
|
|||||||
import CoreImage
|
import CoreImage
|
||||||
import Foundation
|
import Foundation
|
||||||
|
import StableDiffusionProtos
|
||||||
import UniformTypeIdentifiers
|
import UniformTypeIdentifiers
|
||||||
|
|
||||||
extension CGImage {
|
extension CGImage {
|
||||||
func toPngData() throws -> Data {
|
func toImageData(format: SdImageFormat) throws -> Data {
|
||||||
guard let data = CFDataCreateMutable(nil, 0) else {
|
guard let data = CFDataCreateMutable(nil, 0) else {
|
||||||
throw SdCoreError.imageEncode
|
throw SdCoreError.imageEncodeFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
guard let destination = CGImageDestinationCreateWithData(data, "public.png" as CFString, 1, nil) else {
|
guard let destination = try CGImageDestinationCreateWithData(data, formatToTypeIdentifier(format) as CFString, 1, nil) else {
|
||||||
throw SdCoreError.imageEncode
|
throw SdCoreError.imageEncodeFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
CGImageDestinationAddImage(destination, self, nil)
|
CGImageDestinationAddImage(destination, self, nil)
|
||||||
if CGImageDestinationFinalize(destination) {
|
if CGImageDestinationFinalize(destination) {
|
||||||
return data as Data
|
return data as Data
|
||||||
} else {
|
} else {
|
||||||
throw SdCoreError.imageEncode
|
throw SdCoreError.imageEncodeFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSdImage(format: SdImageFormat) throws -> SdImage {
|
||||||
|
let content = try toImageData(format: format)
|
||||||
|
var image = SdImage()
|
||||||
|
image.format = format
|
||||||
|
image.data = content
|
||||||
|
return image
|
||||||
|
}
|
||||||
|
|
||||||
|
private func formatToTypeIdentifier(_ format: SdImageFormat) throws -> String {
|
||||||
|
switch format {
|
||||||
|
case .png: return "public.png"
|
||||||
|
default: throw SdCoreError.imageEncodeFailed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ public actor ModelManager {
|
|||||||
self.modelBaseURL = modelBaseURL
|
self.modelBaseURL = modelBaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
public func reloadModels() throws {
|
public func reloadAvailableModels() throws {
|
||||||
modelInfos.removeAll()
|
modelInfos.removeAll()
|
||||||
modelUrls.removeAll()
|
modelUrls.removeAll()
|
||||||
modelStates.removeAll()
|
modelStates.removeAll()
|
||||||
@ -26,8 +26,37 @@ public actor ModelManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func listModels() -> [SdModelInfo] {
|
public func listAvailableModels() async throws -> [SdModelInfo] {
|
||||||
Array(modelInfos.values)
|
var results: [SdModelInfo] = []
|
||||||
|
for simpleInfo in modelInfos.values {
|
||||||
|
var info = try SdModelInfo(jsonString: simpleInfo.jsonString())
|
||||||
|
if let maybeLoaded = modelStates[info.name] {
|
||||||
|
info.isLoaded = await maybeLoaded.isModelLoaded()
|
||||||
|
if let loadedComputeUnits = await maybeLoaded.loadedModelComputeUnits() {
|
||||||
|
info.loadedComputeUnits = loadedComputeUnits
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info.isLoaded = false
|
||||||
|
info.loadedComputeUnits = .init()
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.attention == .splitEinSum {
|
||||||
|
info.supportedComputeUnits = [
|
||||||
|
.cpuAndGpu,
|
||||||
|
.cpuAndNeuralEngine,
|
||||||
|
.cpu,
|
||||||
|
.all
|
||||||
|
]
|
||||||
|
} else {
|
||||||
|
info.supportedComputeUnits = [
|
||||||
|
.cpuAndGpu,
|
||||||
|
.cpu
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
results.append(info)
|
||||||
|
}
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
public func createModelState(name: String) throws -> ModelState {
|
public func createModelState(name: String) throws -> ModelState {
|
||||||
@ -53,13 +82,14 @@ public actor ModelManager {
|
|||||||
private func addModel(url: URL) throws {
|
private func addModel(url: URL) throws {
|
||||||
var info = SdModelInfo()
|
var info = SdModelInfo()
|
||||||
info.name = url.lastPathComponent
|
info.name = url.lastPathComponent
|
||||||
let attention = getModelAttention(url)
|
if let attention = getModelAttention(url) {
|
||||||
info.attention = attention ?? "unknown"
|
info.attention = attention
|
||||||
|
}
|
||||||
modelInfos[info.name] = info
|
modelInfos[info.name] = info
|
||||||
modelUrls[info.name] = url
|
modelUrls[info.name] = url
|
||||||
}
|
}
|
||||||
|
|
||||||
private func getModelAttention(_ url: URL) -> String? {
|
private func getModelAttention(_ url: URL) -> SdModelAttention? {
|
||||||
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
|
let unetMetadataURL = url.appending(components: "Unet.mlmodelc", "metadata.json")
|
||||||
|
|
||||||
struct ModelMetadata: Decodable {
|
struct ModelMetadata: Decodable {
|
||||||
@ -74,7 +104,7 @@ public actor ModelManager {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? "split-einsum" : "original"
|
return metadatas[0].mlProgramOperationTypeHistogram["Ios16.einsum"] != nil ? SdModelAttention.splitEinSum : SdModelAttention.original
|
||||||
} catch {
|
} catch {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -7,14 +7,15 @@ public actor ModelState {
|
|||||||
private let url: URL
|
private let url: URL
|
||||||
private var pipeline: StableDiffusionPipeline?
|
private var pipeline: StableDiffusionPipeline?
|
||||||
private var tokenizer: BPETokenizer?
|
private var tokenizer: BPETokenizer?
|
||||||
|
private var loadedConfiguration: MLModelConfiguration?
|
||||||
|
|
||||||
public init(url: URL) {
|
public init(url: URL) {
|
||||||
self.url = url
|
self.url = url
|
||||||
}
|
}
|
||||||
|
|
||||||
public func load() throws {
|
public func load(request: SdLoadModelRequest) throws {
|
||||||
let config = MLModelConfiguration()
|
let config = MLModelConfiguration()
|
||||||
config.computeUnits = .cpuAndGPU
|
config.computeUnits = request.computeUnits.toMlComputeUnits()
|
||||||
pipeline = try StableDiffusionPipeline(
|
pipeline = try StableDiffusionPipeline(
|
||||||
resourcesAt: url,
|
resourcesAt: url,
|
||||||
controlNet: [],
|
controlNet: [],
|
||||||
@ -26,6 +27,15 @@ public actor ModelState {
|
|||||||
let vocabUrl = url.appending(component: "vocab.json")
|
let vocabUrl = url.appending(component: "vocab.json")
|
||||||
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
||||||
try pipeline?.loadResources()
|
try pipeline?.loadResources()
|
||||||
|
loadedConfiguration = config
|
||||||
|
}
|
||||||
|
|
||||||
|
public func isModelLoaded() -> Bool {
|
||||||
|
pipeline != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
public func loadedModelComputeUnits() -> SdComputeUnits? {
|
||||||
|
loadedConfiguration?.computeUnits.toSdComputeUnits()
|
||||||
}
|
}
|
||||||
|
|
||||||
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
||||||
@ -33,20 +43,25 @@ public actor ModelState {
|
|||||||
throw SdCoreError.modelNotLoaded
|
throw SdCoreError.modelNotLoaded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let baseSeed: UInt32 = request.seed
|
||||||
|
|
||||||
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
||||||
pipelineConfig.negativePrompt = request.negativePrompt
|
pipelineConfig.negativePrompt = request.negativePrompt
|
||||||
pipelineConfig.seed = UInt32.random(in: 0 ..< UInt32.max)
|
pipelineConfig.imageCount = Int(request.batchSize)
|
||||||
|
|
||||||
var response = SdGenerateImagesResponse()
|
var response = SdGenerateImagesResponse()
|
||||||
for _ in 0 ..< request.imageCount {
|
for _ in 0 ..< request.batchCount {
|
||||||
|
var seed = baseSeed
|
||||||
|
if seed == 0 {
|
||||||
|
seed = UInt32.random(in: 0 ..< UInt32.max)
|
||||||
|
}
|
||||||
|
pipelineConfig.seed = seed
|
||||||
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
let images = try pipeline.generateImages(configuration: pipelineConfig)
|
||||||
|
|
||||||
for cgImage in images {
|
for cgImage in images {
|
||||||
guard let cgImage else { continue }
|
guard let cgImage else { continue }
|
||||||
var image = SdImage()
|
try response.images.append(cgImage.toSdImage(format: request.outputImageFormat))
|
||||||
image.content = try cgImage.toPngData()
|
|
||||||
response.images.append(image)
|
|
||||||
}
|
}
|
||||||
|
response.seeds.append(seed)
|
||||||
}
|
}
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,9 @@ import NIOConcurrencyHelpers
|
|||||||
import SwiftProtobuf
|
import SwiftProtobuf
|
||||||
|
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model service, for management and loading of models.
|
||||||
|
///
|
||||||
/// Usage: instantiate `SdModelServiceClient`, then call methods of this protocol to make API calls.
|
/// Usage: instantiate `SdModelServiceClient`, then call methods of this protocol to make API calls.
|
||||||
public protocol SdModelServiceClientProtocol: GRPCClient {
|
public protocol SdModelServiceClientProtocol: GRPCClient {
|
||||||
var serviceName: String { get }
|
var serviceName: String { get }
|
||||||
@ -36,11 +39,6 @@ public protocol SdModelServiceClientProtocol: GRPCClient {
|
|||||||
callOptions: CallOptions?
|
callOptions: CallOptions?
|
||||||
) -> UnaryCall<SdListModelsRequest, SdListModelsResponse>
|
) -> UnaryCall<SdListModelsRequest, SdListModelsResponse>
|
||||||
|
|
||||||
func reloadModels(
|
|
||||||
_ request: SdReloadModelsRequest,
|
|
||||||
callOptions: CallOptions?
|
|
||||||
) -> UnaryCall<SdReloadModelsRequest, SdReloadModelsResponse>
|
|
||||||
|
|
||||||
func loadModel(
|
func loadModel(
|
||||||
_ request: SdLoadModelRequest,
|
_ request: SdLoadModelRequest,
|
||||||
callOptions: CallOptions?
|
callOptions: CallOptions?
|
||||||
@ -52,7 +50,9 @@ extension SdModelServiceClientProtocol {
|
|||||||
return "gay.pizza.stable.diffusion.ModelService"
|
return "gay.pizza.stable.diffusion.ModelService"
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unary call to ListModels
|
///*
|
||||||
|
/// Lists the available models on the host.
|
||||||
|
/// This will return both models that are currently loaded, and models that are not yet loaded.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - request: Request to send to ListModels.
|
/// - request: Request to send to ListModels.
|
||||||
@ -70,25 +70,8 @@ extension SdModelServiceClientProtocol {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unary call to ReloadModels
|
///*
|
||||||
///
|
/// Loads a model onto a compute unit.
|
||||||
/// - Parameters:
|
|
||||||
/// - request: Request to send to ReloadModels.
|
|
||||||
/// - callOptions: Call options.
|
|
||||||
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
|
|
||||||
public func reloadModels(
|
|
||||||
_ request: SdReloadModelsRequest,
|
|
||||||
callOptions: CallOptions? = nil
|
|
||||||
) -> UnaryCall<SdReloadModelsRequest, SdReloadModelsResponse> {
|
|
||||||
return self.makeUnaryCall(
|
|
||||||
path: SdModelServiceClientMetadata.Methods.reloadModels.path,
|
|
||||||
request: request,
|
|
||||||
callOptions: callOptions ?? self.defaultCallOptions,
|
|
||||||
interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? []
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unary call to LoadModel
|
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - request: Request to send to LoadModel.
|
/// - request: Request to send to LoadModel.
|
||||||
@ -167,6 +150,8 @@ public struct SdModelServiceNIOClient: SdModelServiceClientProtocol {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if compiler(>=5.6)
|
#if compiler(>=5.6)
|
||||||
|
///*
|
||||||
|
/// The model service, for management and loading of models.
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
public protocol SdModelServiceAsyncClientProtocol: GRPCClient {
|
public protocol SdModelServiceAsyncClientProtocol: GRPCClient {
|
||||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||||
@ -177,11 +162,6 @@ public protocol SdModelServiceAsyncClientProtocol: GRPCClient {
|
|||||||
callOptions: CallOptions?
|
callOptions: CallOptions?
|
||||||
) -> GRPCAsyncUnaryCall<SdListModelsRequest, SdListModelsResponse>
|
) -> GRPCAsyncUnaryCall<SdListModelsRequest, SdListModelsResponse>
|
||||||
|
|
||||||
func makeReloadModelsCall(
|
|
||||||
_ request: SdReloadModelsRequest,
|
|
||||||
callOptions: CallOptions?
|
|
||||||
) -> GRPCAsyncUnaryCall<SdReloadModelsRequest, SdReloadModelsResponse>
|
|
||||||
|
|
||||||
func makeLoadModelCall(
|
func makeLoadModelCall(
|
||||||
_ request: SdLoadModelRequest,
|
_ request: SdLoadModelRequest,
|
||||||
callOptions: CallOptions?
|
callOptions: CallOptions?
|
||||||
@ -210,18 +190,6 @@ extension SdModelServiceAsyncClientProtocol {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public func makeReloadModelsCall(
|
|
||||||
_ request: SdReloadModelsRequest,
|
|
||||||
callOptions: CallOptions? = nil
|
|
||||||
) -> GRPCAsyncUnaryCall<SdReloadModelsRequest, SdReloadModelsResponse> {
|
|
||||||
return self.makeAsyncUnaryCall(
|
|
||||||
path: SdModelServiceClientMetadata.Methods.reloadModels.path,
|
|
||||||
request: request,
|
|
||||||
callOptions: callOptions ?? self.defaultCallOptions,
|
|
||||||
interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? []
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
public func makeLoadModelCall(
|
public func makeLoadModelCall(
|
||||||
_ request: SdLoadModelRequest,
|
_ request: SdLoadModelRequest,
|
||||||
callOptions: CallOptions? = nil
|
callOptions: CallOptions? = nil
|
||||||
@ -249,18 +217,6 @@ extension SdModelServiceAsyncClientProtocol {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public func reloadModels(
|
|
||||||
_ request: SdReloadModelsRequest,
|
|
||||||
callOptions: CallOptions? = nil
|
|
||||||
) async throws -> SdReloadModelsResponse {
|
|
||||||
return try await self.performAsyncUnaryCall(
|
|
||||||
path: SdModelServiceClientMetadata.Methods.reloadModels.path,
|
|
||||||
request: request,
|
|
||||||
callOptions: callOptions ?? self.defaultCallOptions,
|
|
||||||
interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? []
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
public func loadModel(
|
public func loadModel(
|
||||||
_ request: SdLoadModelRequest,
|
_ request: SdLoadModelRequest,
|
||||||
callOptions: CallOptions? = nil
|
callOptions: CallOptions? = nil
|
||||||
@ -298,9 +254,6 @@ public protocol SdModelServiceClientInterceptorFactoryProtocol: GRPCSendable {
|
|||||||
/// - Returns: Interceptors to use when invoking 'listModels'.
|
/// - Returns: Interceptors to use when invoking 'listModels'.
|
||||||
func makeListModelsInterceptors() -> [ClientInterceptor<SdListModelsRequest, SdListModelsResponse>]
|
func makeListModelsInterceptors() -> [ClientInterceptor<SdListModelsRequest, SdListModelsResponse>]
|
||||||
|
|
||||||
/// - Returns: Interceptors to use when invoking 'reloadModels'.
|
|
||||||
func makeReloadModelsInterceptors() -> [ClientInterceptor<SdReloadModelsRequest, SdReloadModelsResponse>]
|
|
||||||
|
|
||||||
/// - Returns: Interceptors to use when invoking 'loadModel'.
|
/// - Returns: Interceptors to use when invoking 'loadModel'.
|
||||||
func makeLoadModelInterceptors() -> [ClientInterceptor<SdLoadModelRequest, SdLoadModelResponse>]
|
func makeLoadModelInterceptors() -> [ClientInterceptor<SdLoadModelRequest, SdLoadModelResponse>]
|
||||||
}
|
}
|
||||||
@ -311,7 +264,6 @@ public enum SdModelServiceClientMetadata {
|
|||||||
fullName: "gay.pizza.stable.diffusion.ModelService",
|
fullName: "gay.pizza.stable.diffusion.ModelService",
|
||||||
methods: [
|
methods: [
|
||||||
SdModelServiceClientMetadata.Methods.listModels,
|
SdModelServiceClientMetadata.Methods.listModels,
|
||||||
SdModelServiceClientMetadata.Methods.reloadModels,
|
|
||||||
SdModelServiceClientMetadata.Methods.loadModel,
|
SdModelServiceClientMetadata.Methods.loadModel,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -323,12 +275,6 @@ public enum SdModelServiceClientMetadata {
|
|||||||
type: GRPCCallType.unary
|
type: GRPCCallType.unary
|
||||||
)
|
)
|
||||||
|
|
||||||
public static let reloadModels = GRPCMethodDescriptor(
|
|
||||||
name: "ReloadModels",
|
|
||||||
path: "/gay.pizza.stable.diffusion.ModelService/ReloadModels",
|
|
||||||
type: GRPCCallType.unary
|
|
||||||
)
|
|
||||||
|
|
||||||
public static let loadModel = GRPCMethodDescriptor(
|
public static let loadModel = GRPCMethodDescriptor(
|
||||||
name: "LoadModel",
|
name: "LoadModel",
|
||||||
path: "/gay.pizza.stable.diffusion.ModelService/LoadModel",
|
path: "/gay.pizza.stable.diffusion.ModelService/LoadModel",
|
||||||
@ -337,12 +283,15 @@ public enum SdModelServiceClientMetadata {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The image generation service, for generating images from loaded models.
|
||||||
|
///
|
||||||
/// Usage: instantiate `SdImageGenerationServiceClient`, then call methods of this protocol to make API calls.
|
/// Usage: instantiate `SdImageGenerationServiceClient`, then call methods of this protocol to make API calls.
|
||||||
public protocol SdImageGenerationServiceClientProtocol: GRPCClient {
|
public protocol SdImageGenerationServiceClientProtocol: GRPCClient {
|
||||||
var serviceName: String { get }
|
var serviceName: String { get }
|
||||||
var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get }
|
var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get }
|
||||||
|
|
||||||
func generateImage(
|
func generateImages(
|
||||||
_ request: SdGenerateImagesRequest,
|
_ request: SdGenerateImagesRequest,
|
||||||
callOptions: CallOptions?
|
callOptions: CallOptions?
|
||||||
) -> UnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
|
) -> UnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
|
||||||
@ -353,21 +302,22 @@ extension SdImageGenerationServiceClientProtocol {
|
|||||||
return "gay.pizza.stable.diffusion.ImageGenerationService"
|
return "gay.pizza.stable.diffusion.ImageGenerationService"
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unary call to GenerateImage
|
///*
|
||||||
|
/// Generates images using a loaded model.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - request: Request to send to GenerateImage.
|
/// - request: Request to send to GenerateImages.
|
||||||
/// - callOptions: Call options.
|
/// - callOptions: Call options.
|
||||||
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
|
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
|
||||||
public func generateImage(
|
public func generateImages(
|
||||||
_ request: SdGenerateImagesRequest,
|
_ request: SdGenerateImagesRequest,
|
||||||
callOptions: CallOptions? = nil
|
callOptions: CallOptions? = nil
|
||||||
) -> UnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse> {
|
) -> UnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse> {
|
||||||
return self.makeUnaryCall(
|
return self.makeUnaryCall(
|
||||||
path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path,
|
path: SdImageGenerationServiceClientMetadata.Methods.generateImages.path,
|
||||||
request: request,
|
request: request,
|
||||||
callOptions: callOptions ?? self.defaultCallOptions,
|
callOptions: callOptions ?? self.defaultCallOptions,
|
||||||
interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? []
|
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -432,12 +382,14 @@ public struct SdImageGenerationServiceNIOClient: SdImageGenerationServiceClientP
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if compiler(>=5.6)
|
#if compiler(>=5.6)
|
||||||
|
///*
|
||||||
|
/// The image generation service, for generating images from loaded models.
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient {
|
public protocol SdImageGenerationServiceAsyncClientProtocol: GRPCClient {
|
||||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||||
var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get }
|
var interceptors: SdImageGenerationServiceClientInterceptorFactoryProtocol? { get }
|
||||||
|
|
||||||
func makeGenerateImageCall(
|
func makeGenerateImagesCall(
|
||||||
_ request: SdGenerateImagesRequest,
|
_ request: SdGenerateImagesRequest,
|
||||||
callOptions: CallOptions?
|
callOptions: CallOptions?
|
||||||
) -> GRPCAsyncUnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
|
) -> GRPCAsyncUnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse>
|
||||||
@ -453,30 +405,30 @@ extension SdImageGenerationServiceAsyncClientProtocol {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
public func makeGenerateImageCall(
|
public func makeGenerateImagesCall(
|
||||||
_ request: SdGenerateImagesRequest,
|
_ request: SdGenerateImagesRequest,
|
||||||
callOptions: CallOptions? = nil
|
callOptions: CallOptions? = nil
|
||||||
) -> GRPCAsyncUnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse> {
|
) -> GRPCAsyncUnaryCall<SdGenerateImagesRequest, SdGenerateImagesResponse> {
|
||||||
return self.makeAsyncUnaryCall(
|
return self.makeAsyncUnaryCall(
|
||||||
path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path,
|
path: SdImageGenerationServiceClientMetadata.Methods.generateImages.path,
|
||||||
request: request,
|
request: request,
|
||||||
callOptions: callOptions ?? self.defaultCallOptions,
|
callOptions: callOptions ?? self.defaultCallOptions,
|
||||||
interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? []
|
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
extension SdImageGenerationServiceAsyncClientProtocol {
|
extension SdImageGenerationServiceAsyncClientProtocol {
|
||||||
public func generateImage(
|
public func generateImages(
|
||||||
_ request: SdGenerateImagesRequest,
|
_ request: SdGenerateImagesRequest,
|
||||||
callOptions: CallOptions? = nil
|
callOptions: CallOptions? = nil
|
||||||
) async throws -> SdGenerateImagesResponse {
|
) async throws -> SdGenerateImagesResponse {
|
||||||
return try await self.performAsyncUnaryCall(
|
return try await self.performAsyncUnaryCall(
|
||||||
path: SdImageGenerationServiceClientMetadata.Methods.generateImage.path,
|
path: SdImageGenerationServiceClientMetadata.Methods.generateImages.path,
|
||||||
request: request,
|
request: request,
|
||||||
callOptions: callOptions ?? self.defaultCallOptions,
|
callOptions: callOptions ?? self.defaultCallOptions,
|
||||||
interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? []
|
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? []
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -502,8 +454,8 @@ public struct SdImageGenerationServiceAsyncClient: SdImageGenerationServiceAsync
|
|||||||
|
|
||||||
public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSendable {
|
public protocol SdImageGenerationServiceClientInterceptorFactoryProtocol: GRPCSendable {
|
||||||
|
|
||||||
/// - Returns: Interceptors to use when invoking 'generateImage'.
|
/// - Returns: Interceptors to use when invoking 'generateImages'.
|
||||||
func makeGenerateImageInterceptors() -> [ClientInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
|
func makeGenerateImagesInterceptors() -> [ClientInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
|
||||||
}
|
}
|
||||||
|
|
||||||
public enum SdImageGenerationServiceClientMetadata {
|
public enum SdImageGenerationServiceClientMetadata {
|
||||||
@ -511,27 +463,33 @@ public enum SdImageGenerationServiceClientMetadata {
|
|||||||
name: "ImageGenerationService",
|
name: "ImageGenerationService",
|
||||||
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
|
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
|
||||||
methods: [
|
methods: [
|
||||||
SdImageGenerationServiceClientMetadata.Methods.generateImage,
|
SdImageGenerationServiceClientMetadata.Methods.generateImages,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
public enum Methods {
|
public enum Methods {
|
||||||
public static let generateImage = GRPCMethodDescriptor(
|
public static let generateImages = GRPCMethodDescriptor(
|
||||||
name: "GenerateImage",
|
name: "GenerateImages",
|
||||||
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImage",
|
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages",
|
||||||
type: GRPCCallType.unary
|
type: GRPCCallType.unary
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model service, for management and loading of models.
|
||||||
|
///
|
||||||
/// To build a server, implement a class that conforms to this protocol.
|
/// To build a server, implement a class that conforms to this protocol.
|
||||||
public protocol SdModelServiceProvider: CallHandlerProvider {
|
public protocol SdModelServiceProvider: CallHandlerProvider {
|
||||||
var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get }
|
var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get }
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Lists the available models on the host.
|
||||||
|
/// This will return both models that are currently loaded, and models that are not yet loaded.
|
||||||
func listModels(request: SdListModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdListModelsResponse>
|
func listModels(request: SdListModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdListModelsResponse>
|
||||||
|
|
||||||
func reloadModels(request: SdReloadModelsRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdReloadModelsResponse>
|
///*
|
||||||
|
/// Loads a model onto a compute unit.
|
||||||
func loadModel(request: SdLoadModelRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdLoadModelResponse>
|
func loadModel(request: SdLoadModelRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdLoadModelResponse>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -556,15 +514,6 @@ extension SdModelServiceProvider {
|
|||||||
userFunction: self.listModels(request:context:)
|
userFunction: self.listModels(request:context:)
|
||||||
)
|
)
|
||||||
|
|
||||||
case "ReloadModels":
|
|
||||||
return UnaryServerHandler(
|
|
||||||
context: context,
|
|
||||||
requestDeserializer: ProtobufDeserializer<SdReloadModelsRequest>(),
|
|
||||||
responseSerializer: ProtobufSerializer<SdReloadModelsResponse>(),
|
|
||||||
interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [],
|
|
||||||
userFunction: self.reloadModels(request:context:)
|
|
||||||
)
|
|
||||||
|
|
||||||
case "LoadModel":
|
case "LoadModel":
|
||||||
return UnaryServerHandler(
|
return UnaryServerHandler(
|
||||||
context: context,
|
context: context,
|
||||||
@ -582,22 +531,25 @@ extension SdModelServiceProvider {
|
|||||||
|
|
||||||
#if compiler(>=5.6)
|
#if compiler(>=5.6)
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model service, for management and loading of models.
|
||||||
|
///
|
||||||
/// To implement a server, implement an object which conforms to this protocol.
|
/// To implement a server, implement an object which conforms to this protocol.
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
public protocol SdModelServiceAsyncProvider: CallHandlerProvider {
|
public protocol SdModelServiceAsyncProvider: CallHandlerProvider {
|
||||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||||
var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get }
|
var interceptors: SdModelServiceServerInterceptorFactoryProtocol? { get }
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Lists the available models on the host.
|
||||||
|
/// This will return both models that are currently loaded, and models that are not yet loaded.
|
||||||
@Sendable func listModels(
|
@Sendable func listModels(
|
||||||
request: SdListModelsRequest,
|
request: SdListModelsRequest,
|
||||||
context: GRPCAsyncServerCallContext
|
context: GRPCAsyncServerCallContext
|
||||||
) async throws -> SdListModelsResponse
|
) async throws -> SdListModelsResponse
|
||||||
|
|
||||||
@Sendable func reloadModels(
|
///*
|
||||||
request: SdReloadModelsRequest,
|
/// Loads a model onto a compute unit.
|
||||||
context: GRPCAsyncServerCallContext
|
|
||||||
) async throws -> SdReloadModelsResponse
|
|
||||||
|
|
||||||
@Sendable func loadModel(
|
@Sendable func loadModel(
|
||||||
request: SdLoadModelRequest,
|
request: SdLoadModelRequest,
|
||||||
context: GRPCAsyncServerCallContext
|
context: GRPCAsyncServerCallContext
|
||||||
@ -632,15 +584,6 @@ extension SdModelServiceAsyncProvider {
|
|||||||
wrapping: self.listModels(request:context:)
|
wrapping: self.listModels(request:context:)
|
||||||
)
|
)
|
||||||
|
|
||||||
case "ReloadModels":
|
|
||||||
return GRPCAsyncServerHandler(
|
|
||||||
context: context,
|
|
||||||
requestDeserializer: ProtobufDeserializer<SdReloadModelsRequest>(),
|
|
||||||
responseSerializer: ProtobufSerializer<SdReloadModelsResponse>(),
|
|
||||||
interceptors: self.interceptors?.makeReloadModelsInterceptors() ?? [],
|
|
||||||
wrapping: self.reloadModels(request:context:)
|
|
||||||
)
|
|
||||||
|
|
||||||
case "LoadModel":
|
case "LoadModel":
|
||||||
return GRPCAsyncServerHandler(
|
return GRPCAsyncServerHandler(
|
||||||
context: context,
|
context: context,
|
||||||
@ -664,10 +607,6 @@ public protocol SdModelServiceServerInterceptorFactoryProtocol {
|
|||||||
/// Defaults to calling `self.makeInterceptors()`.
|
/// Defaults to calling `self.makeInterceptors()`.
|
||||||
func makeListModelsInterceptors() -> [ServerInterceptor<SdListModelsRequest, SdListModelsResponse>]
|
func makeListModelsInterceptors() -> [ServerInterceptor<SdListModelsRequest, SdListModelsResponse>]
|
||||||
|
|
||||||
/// - Returns: Interceptors to use when handling 'reloadModels'.
|
|
||||||
/// Defaults to calling `self.makeInterceptors()`.
|
|
||||||
func makeReloadModelsInterceptors() -> [ServerInterceptor<SdReloadModelsRequest, SdReloadModelsResponse>]
|
|
||||||
|
|
||||||
/// - Returns: Interceptors to use when handling 'loadModel'.
|
/// - Returns: Interceptors to use when handling 'loadModel'.
|
||||||
/// Defaults to calling `self.makeInterceptors()`.
|
/// Defaults to calling `self.makeInterceptors()`.
|
||||||
func makeLoadModelInterceptors() -> [ServerInterceptor<SdLoadModelRequest, SdLoadModelResponse>]
|
func makeLoadModelInterceptors() -> [ServerInterceptor<SdLoadModelRequest, SdLoadModelResponse>]
|
||||||
@ -679,7 +618,6 @@ public enum SdModelServiceServerMetadata {
|
|||||||
fullName: "gay.pizza.stable.diffusion.ModelService",
|
fullName: "gay.pizza.stable.diffusion.ModelService",
|
||||||
methods: [
|
methods: [
|
||||||
SdModelServiceServerMetadata.Methods.listModels,
|
SdModelServiceServerMetadata.Methods.listModels,
|
||||||
SdModelServiceServerMetadata.Methods.reloadModels,
|
|
||||||
SdModelServiceServerMetadata.Methods.loadModel,
|
SdModelServiceServerMetadata.Methods.loadModel,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -691,12 +629,6 @@ public enum SdModelServiceServerMetadata {
|
|||||||
type: GRPCCallType.unary
|
type: GRPCCallType.unary
|
||||||
)
|
)
|
||||||
|
|
||||||
public static let reloadModels = GRPCMethodDescriptor(
|
|
||||||
name: "ReloadModels",
|
|
||||||
path: "/gay.pizza.stable.diffusion.ModelService/ReloadModels",
|
|
||||||
type: GRPCCallType.unary
|
|
||||||
)
|
|
||||||
|
|
||||||
public static let loadModel = GRPCMethodDescriptor(
|
public static let loadModel = GRPCMethodDescriptor(
|
||||||
name: "LoadModel",
|
name: "LoadModel",
|
||||||
path: "/gay.pizza.stable.diffusion.ModelService/LoadModel",
|
path: "/gay.pizza.stable.diffusion.ModelService/LoadModel",
|
||||||
@ -704,11 +636,16 @@ public enum SdModelServiceServerMetadata {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
///*
|
||||||
|
/// The image generation service, for generating images from loaded models.
|
||||||
|
///
|
||||||
/// To build a server, implement a class that conforms to this protocol.
|
/// To build a server, implement a class that conforms to this protocol.
|
||||||
public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
|
public protocol SdImageGenerationServiceProvider: CallHandlerProvider {
|
||||||
var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get }
|
var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get }
|
||||||
|
|
||||||
func generateImage(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
|
///*
|
||||||
|
/// Generates images using a loaded model.
|
||||||
|
func generateImages(request: SdGenerateImagesRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdGenerateImagesResponse>
|
||||||
}
|
}
|
||||||
|
|
||||||
extension SdImageGenerationServiceProvider {
|
extension SdImageGenerationServiceProvider {
|
||||||
@ -723,13 +660,13 @@ extension SdImageGenerationServiceProvider {
|
|||||||
context: CallHandlerContext
|
context: CallHandlerContext
|
||||||
) -> GRPCServerHandlerProtocol? {
|
) -> GRPCServerHandlerProtocol? {
|
||||||
switch name {
|
switch name {
|
||||||
case "GenerateImage":
|
case "GenerateImages":
|
||||||
return UnaryServerHandler(
|
return UnaryServerHandler(
|
||||||
context: context,
|
context: context,
|
||||||
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
|
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
|
||||||
responseSerializer: ProtobufSerializer<SdGenerateImagesResponse>(),
|
responseSerializer: ProtobufSerializer<SdGenerateImagesResponse>(),
|
||||||
interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [],
|
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [],
|
||||||
userFunction: self.generateImage(request:context:)
|
userFunction: self.generateImages(request:context:)
|
||||||
)
|
)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -740,13 +677,18 @@ extension SdImageGenerationServiceProvider {
|
|||||||
|
|
||||||
#if compiler(>=5.6)
|
#if compiler(>=5.6)
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The image generation service, for generating images from loaded models.
|
||||||
|
///
|
||||||
/// To implement a server, implement an object which conforms to this protocol.
|
/// To implement a server, implement an object which conforms to this protocol.
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
|
public protocol SdImageGenerationServiceAsyncProvider: CallHandlerProvider {
|
||||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||||
var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get }
|
var interceptors: SdImageGenerationServiceServerInterceptorFactoryProtocol? { get }
|
||||||
|
|
||||||
@Sendable func generateImage(
|
///*
|
||||||
|
/// Generates images using a loaded model.
|
||||||
|
@Sendable func generateImages(
|
||||||
request: SdGenerateImagesRequest,
|
request: SdGenerateImagesRequest,
|
||||||
context: GRPCAsyncServerCallContext
|
context: GRPCAsyncServerCallContext
|
||||||
) async throws -> SdGenerateImagesResponse
|
) async throws -> SdGenerateImagesResponse
|
||||||
@ -771,13 +713,13 @@ extension SdImageGenerationServiceAsyncProvider {
|
|||||||
context: CallHandlerContext
|
context: CallHandlerContext
|
||||||
) -> GRPCServerHandlerProtocol? {
|
) -> GRPCServerHandlerProtocol? {
|
||||||
switch name {
|
switch name {
|
||||||
case "GenerateImage":
|
case "GenerateImages":
|
||||||
return GRPCAsyncServerHandler(
|
return GRPCAsyncServerHandler(
|
||||||
context: context,
|
context: context,
|
||||||
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
|
requestDeserializer: ProtobufDeserializer<SdGenerateImagesRequest>(),
|
||||||
responseSerializer: ProtobufSerializer<SdGenerateImagesResponse>(),
|
responseSerializer: ProtobufSerializer<SdGenerateImagesResponse>(),
|
||||||
interceptors: self.interceptors?.makeGenerateImageInterceptors() ?? [],
|
interceptors: self.interceptors?.makeGenerateImagesInterceptors() ?? [],
|
||||||
wrapping: self.generateImage(request:context:)
|
wrapping: self.generateImages(request:context:)
|
||||||
)
|
)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -790,9 +732,9 @@ extension SdImageGenerationServiceAsyncProvider {
|
|||||||
|
|
||||||
public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol {
|
public protocol SdImageGenerationServiceServerInterceptorFactoryProtocol {
|
||||||
|
|
||||||
/// - Returns: Interceptors to use when handling 'generateImage'.
|
/// - Returns: Interceptors to use when handling 'generateImages'.
|
||||||
/// Defaults to calling `self.makeInterceptors()`.
|
/// Defaults to calling `self.makeInterceptors()`.
|
||||||
func makeGenerateImageInterceptors() -> [ServerInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
|
func makeGenerateImagesInterceptors() -> [ServerInterceptor<SdGenerateImagesRequest, SdGenerateImagesResponse>]
|
||||||
}
|
}
|
||||||
|
|
||||||
public enum SdImageGenerationServiceServerMetadata {
|
public enum SdImageGenerationServiceServerMetadata {
|
||||||
@ -800,14 +742,14 @@ public enum SdImageGenerationServiceServerMetadata {
|
|||||||
name: "ImageGenerationService",
|
name: "ImageGenerationService",
|
||||||
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
|
fullName: "gay.pizza.stable.diffusion.ImageGenerationService",
|
||||||
methods: [
|
methods: [
|
||||||
SdImageGenerationServiceServerMetadata.Methods.generateImage,
|
SdImageGenerationServiceServerMetadata.Methods.generateImages,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
public enum Methods {
|
public enum Methods {
|
||||||
public static let generateImage = GRPCMethodDescriptor(
|
public static let generateImages = GRPCMethodDescriptor(
|
||||||
name: "GenerateImage",
|
name: "GenerateImages",
|
||||||
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImage",
|
path: "/gay.pizza.stable.diffusion.ImageGenerationService/GenerateImages",
|
||||||
type: GRPCCallType.unary
|
type: GRPCCallType.unary
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,9 @@
|
|||||||
// For information on using the generated types, please see the documentation:
|
// For information on using the generated types, please see the documentation:
|
||||||
// https://github.com/apple/swift-protobuf/
|
// https://github.com/apple/swift-protobuf/
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Stable Diffusion RPC service for Apple Platforms.
|
||||||
|
|
||||||
import Foundation
|
import Foundation
|
||||||
import SwiftProtobuf
|
import SwiftProtobuf
|
||||||
|
|
||||||
@ -20,9 +23,67 @@ fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAP
|
|||||||
typealias Version = _2
|
typealias Version = _2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents the model attention. Model attention has to do with how the model is encoded, and
|
||||||
|
/// can determine what compute units are able to support a particular model.
|
||||||
|
public enum SdModelAttention: SwiftProtobuf.Enum {
|
||||||
|
public typealias RawValue = Int
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model is an original attention type. It can be loaded only onto CPU & GPU compute units.
|
||||||
|
case original // = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model is a split-ein-sum attention type. It can be loaded onto all compute units,
|
||||||
|
/// including the Apple Neural Engine.
|
||||||
|
case splitEinSum // = 1
|
||||||
|
case UNRECOGNIZED(Int)
|
||||||
|
|
||||||
|
public init() {
|
||||||
|
self = .original
|
||||||
|
}
|
||||||
|
|
||||||
|
public init?(rawValue: Int) {
|
||||||
|
switch rawValue {
|
||||||
|
case 0: self = .original
|
||||||
|
case 1: self = .splitEinSum
|
||||||
|
default: self = .UNRECOGNIZED(rawValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public var rawValue: Int {
|
||||||
|
switch self {
|
||||||
|
case .original: return 0
|
||||||
|
case .splitEinSum: return 1
|
||||||
|
case .UNRECOGNIZED(let i): return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#if swift(>=4.2)
|
||||||
|
|
||||||
|
extension SdModelAttention: CaseIterable {
|
||||||
|
// The compiler won't synthesize support with the UNRECOGNIZED case.
|
||||||
|
public static var allCases: [SdModelAttention] = [
|
||||||
|
.original,
|
||||||
|
.splitEinSum,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // swift(>=4.2)
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents the schedulers that are used to sample images.
|
||||||
public enum SdScheduler: SwiftProtobuf.Enum {
|
public enum SdScheduler: SwiftProtobuf.Enum {
|
||||||
public typealias RawValue = Int
|
public typealias RawValue = Int
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The PNDM (Pseudo numerical methods for diffusion models) scheduler.
|
||||||
case pndm // = 0
|
case pndm // = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The DPM-Solver++ scheduler.
|
||||||
case dpmSolverPlusPlus // = 1
|
case dpmSolverPlusPlus // = 1
|
||||||
case UNRECOGNIZED(Int)
|
case UNRECOGNIZED(Int)
|
||||||
|
|
||||||
@ -60,11 +121,25 @@ extension SdScheduler: CaseIterable {
|
|||||||
|
|
||||||
#endif // swift(>=4.2)
|
#endif // swift(>=4.2)
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a specifier for what compute units are available for ML tasks.
|
||||||
public enum SdComputeUnits: SwiftProtobuf.Enum {
|
public enum SdComputeUnits: SwiftProtobuf.Enum {
|
||||||
public typealias RawValue = Int
|
public typealias RawValue = Int
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The CPU as a singular compute unit.
|
||||||
case cpu // = 0
|
case cpu // = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The CPU & GPU combined into a singular compute unit.
|
||||||
case cpuAndGpu // = 1
|
case cpuAndGpu // = 1
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Allow the usage of all compute units. CoreML will decided where the model is loaded.
|
||||||
case all // = 2
|
case all // = 2
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The CPU & Neural Engine combined into a singular compute unit.
|
||||||
case cpuAndNeuralEngine // = 3
|
case cpuAndNeuralEngine // = 3
|
||||||
case UNRECOGNIZED(Int)
|
case UNRECOGNIZED(Int)
|
||||||
|
|
||||||
@ -108,34 +183,108 @@ extension SdComputeUnits: CaseIterable {
|
|||||||
|
|
||||||
#endif // swift(>=4.2)
|
#endif // swift(>=4.2)
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents the format of an image.
|
||||||
|
public enum SdImageFormat: SwiftProtobuf.Enum {
|
||||||
|
public typealias RawValue = Int
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The PNG image format.
|
||||||
|
case png // = 0
|
||||||
|
case UNRECOGNIZED(Int)
|
||||||
|
|
||||||
|
public init() {
|
||||||
|
self = .png
|
||||||
|
}
|
||||||
|
|
||||||
|
public init?(rawValue: Int) {
|
||||||
|
switch rawValue {
|
||||||
|
case 0: self = .png
|
||||||
|
default: self = .UNRECOGNIZED(rawValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public var rawValue: Int {
|
||||||
|
switch self {
|
||||||
|
case .png: return 0
|
||||||
|
case .UNRECOGNIZED(let i): return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#if swift(>=4.2)
|
||||||
|
|
||||||
|
extension SdImageFormat: CaseIterable {
|
||||||
|
// The compiler won't synthesize support with the UNRECOGNIZED case.
|
||||||
|
public static var allCases: [SdImageFormat] = [
|
||||||
|
.png,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // swift(>=4.2)
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents information about an available model.
|
||||||
|
/// The primary key of a model is it's 'name' field.
|
||||||
public struct SdModelInfo {
|
public struct SdModelInfo {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The name of the available model. Note that within the context of a single RPC server,
|
||||||
|
/// the name of a model is a unique identifier. This may not be true when utilizing a cluster or
|
||||||
|
/// load balanced server, so keep that in mind.
|
||||||
public var name: String = String()
|
public var name: String = String()
|
||||||
|
|
||||||
public var attention: String = String()
|
///*
|
||||||
|
/// The attention of the model. Model attention determines what compute units can be used to
|
||||||
|
/// load the model and make predictions.
|
||||||
|
public var attention: SdModelAttention = .original
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Whether the model is currently loaded onto an available compute unit.
|
||||||
public var isLoaded: Bool = false
|
public var isLoaded: Bool = false
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The compute unit that the model is currently loaded into, if it is loaded to one at all.
|
||||||
|
/// When is_loaded is false, the value of this field should be null.
|
||||||
|
public var loadedComputeUnits: SdComputeUnits = .cpu
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The compute units that this model supports using.
|
||||||
|
public var supportedComputeUnits: [SdComputeUnits] = []
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents an image within the Stable Diffusion context.
|
||||||
|
/// This could be an input image for an image generation request, or it could be
|
||||||
|
/// a generated image from the Stable Diffusion model.
|
||||||
public struct SdImage {
|
public struct SdImage {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
public var content: Data = Data()
|
///*
|
||||||
|
/// The format of the image.
|
||||||
|
public var format: SdImageFormat = .png
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The raw data of the image, in the specified format.
|
||||||
|
public var data: Data = Data()
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a request to list the models available on the host.
|
||||||
public struct SdListModelsRequest {
|
public struct SdListModelsRequest {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
@ -146,54 +295,44 @@ public struct SdListModelsRequest {
|
|||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a response to listing the models available on the host.
|
||||||
public struct SdListModelsResponse {
|
public struct SdListModelsResponse {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
public var models: [SdModelInfo] = []
|
///*
|
||||||
|
/// The available models on the Stable Diffusion server.
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var availableModels: [SdModelInfo] = []
|
||||||
|
|
||||||
public init() {}
|
|
||||||
}
|
|
||||||
|
|
||||||
public struct SdReloadModelsRequest {
|
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
|
||||||
// methods supported on all messages.
|
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
|
||||||
|
|
||||||
public init() {}
|
|
||||||
}
|
|
||||||
|
|
||||||
public struct SdReloadModelsResponse {
|
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
|
||||||
// methods supported on all messages.
|
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a request to load a model into a specified compute unit.
|
||||||
public struct SdLoadModelRequest {
|
public struct SdLoadModelRequest {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model name to load onto the compute unit.
|
||||||
public var modelName: String = String()
|
public var modelName: String = String()
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The compute units to load the model onto.
|
||||||
public var computeUnits: SdComputeUnits = .cpu
|
public var computeUnits: SdComputeUnits = .cpu
|
||||||
|
|
||||||
public var reduceMemory: Bool = false
|
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a response to loading a model.
|
||||||
public struct SdLoadModelResponse {
|
public struct SdLoadModelResponse {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
@ -204,45 +343,77 @@ public struct SdLoadModelResponse {
|
|||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents a request to generate images using a loaded model.
|
||||||
public struct SdGenerateImagesRequest {
|
public struct SdGenerateImagesRequest {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The model name to use for generation.
|
||||||
|
/// The model must be already be loaded using ModelService.LoadModel RPC method.
|
||||||
public var modelName: String = String()
|
public var modelName: String = String()
|
||||||
|
|
||||||
public var imageCount: UInt32 = 0
|
///*
|
||||||
|
/// The output format for generated images.
|
||||||
|
public var outputImageFormat: SdImageFormat = .png
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The number of batches of images to generate.
|
||||||
|
public var batchCount: UInt32 = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The number of images inside a single batch.
|
||||||
|
public var batchSize: UInt32 = 0
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The positive textual prompt for image generation.
|
||||||
public var prompt: String = String()
|
public var prompt: String = String()
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The negative prompt for image generation.
|
||||||
public var negativePrompt: String = String()
|
public var negativePrompt: String = String()
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The random seed to use.
|
||||||
|
/// Zero indicates that the seed should be random.
|
||||||
|
public var seed: UInt32 = 0
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// Represents the response from image generation.
|
||||||
public struct SdGenerateImagesResponse {
|
public struct SdGenerateImagesResponse {
|
||||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||||
// methods supported on all messages.
|
// methods supported on all messages.
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The set of generated images by the Stable Diffusion pipeline.
|
||||||
public var images: [SdImage] = []
|
public var images: [SdImage] = []
|
||||||
|
|
||||||
|
///*
|
||||||
|
/// The seeds that were used to generate the images.
|
||||||
|
public var seeds: [UInt32] = []
|
||||||
|
|
||||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||||
|
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if swift(>=5.5) && canImport(_Concurrency)
|
#if swift(>=5.5) && canImport(_Concurrency)
|
||||||
|
extension SdModelAttention: @unchecked Sendable {}
|
||||||
extension SdScheduler: @unchecked Sendable {}
|
extension SdScheduler: @unchecked Sendable {}
|
||||||
extension SdComputeUnits: @unchecked Sendable {}
|
extension SdComputeUnits: @unchecked Sendable {}
|
||||||
|
extension SdImageFormat: @unchecked Sendable {}
|
||||||
extension SdModelInfo: @unchecked Sendable {}
|
extension SdModelInfo: @unchecked Sendable {}
|
||||||
extension SdImage: @unchecked Sendable {}
|
extension SdImage: @unchecked Sendable {}
|
||||||
extension SdListModelsRequest: @unchecked Sendable {}
|
extension SdListModelsRequest: @unchecked Sendable {}
|
||||||
extension SdListModelsResponse: @unchecked Sendable {}
|
extension SdListModelsResponse: @unchecked Sendable {}
|
||||||
extension SdReloadModelsRequest: @unchecked Sendable {}
|
|
||||||
extension SdReloadModelsResponse: @unchecked Sendable {}
|
|
||||||
extension SdLoadModelRequest: @unchecked Sendable {}
|
extension SdLoadModelRequest: @unchecked Sendable {}
|
||||||
extension SdLoadModelResponse: @unchecked Sendable {}
|
extension SdLoadModelResponse: @unchecked Sendable {}
|
||||||
extension SdGenerateImagesRequest: @unchecked Sendable {}
|
extension SdGenerateImagesRequest: @unchecked Sendable {}
|
||||||
@ -253,10 +424,17 @@ extension SdGenerateImagesResponse: @unchecked Sendable {}
|
|||||||
|
|
||||||
fileprivate let _protobuf_package = "gay.pizza.stable.diffusion"
|
fileprivate let _protobuf_package = "gay.pizza.stable.diffusion"
|
||||||
|
|
||||||
|
extension SdModelAttention: SwiftProtobuf._ProtoNameProviding {
|
||||||
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
|
0: .same(proto: "original"),
|
||||||
|
1: .same(proto: "split_ein_sum"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
extension SdScheduler: SwiftProtobuf._ProtoNameProviding {
|
extension SdScheduler: SwiftProtobuf._ProtoNameProviding {
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
0: .same(proto: "pndm"),
|
0: .same(proto: "pndm"),
|
||||||
1: .same(proto: "dpmSolverPlusPlus"),
|
1: .same(proto: "dpm_solver_plus_plus"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,12 +447,20 @@ extension SdComputeUnits: SwiftProtobuf._ProtoNameProviding {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extension SdImageFormat: SwiftProtobuf._ProtoNameProviding {
|
||||||
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
|
0: .same(proto: "png"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||||
public static let protoMessageName: String = _protobuf_package + ".ModelInfo"
|
public static let protoMessageName: String = _protobuf_package + ".ModelInfo"
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .same(proto: "name"),
|
1: .same(proto: "name"),
|
||||||
2: .same(proto: "attention"),
|
2: .same(proto: "attention"),
|
||||||
3: .standard(proto: "is_loaded"),
|
3: .standard(proto: "is_loaded"),
|
||||||
|
4: .standard(proto: "loaded_compute_units"),
|
||||||
|
5: .standard(proto: "supported_compute_units"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -284,8 +470,10 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
|||||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeSingularStringField(value: &self.name) }()
|
case 1: try { try decoder.decodeSingularStringField(value: &self.name) }()
|
||||||
case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }()
|
case 2: try { try decoder.decodeSingularEnumField(value: &self.attention) }()
|
||||||
case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }()
|
case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }()
|
||||||
|
case 4: try { try decoder.decodeSingularEnumField(value: &self.loadedComputeUnits) }()
|
||||||
|
case 5: try { try decoder.decodeRepeatedEnumField(value: &self.supportedComputeUnits) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -295,12 +483,18 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
|||||||
if !self.name.isEmpty {
|
if !self.name.isEmpty {
|
||||||
try visitor.visitSingularStringField(value: self.name, fieldNumber: 1)
|
try visitor.visitSingularStringField(value: self.name, fieldNumber: 1)
|
||||||
}
|
}
|
||||||
if !self.attention.isEmpty {
|
if self.attention != .original {
|
||||||
try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2)
|
try visitor.visitSingularEnumField(value: self.attention, fieldNumber: 2)
|
||||||
}
|
}
|
||||||
if self.isLoaded != false {
|
if self.isLoaded != false {
|
||||||
try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3)
|
try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3)
|
||||||
}
|
}
|
||||||
|
if self.loadedComputeUnits != .cpu {
|
||||||
|
try visitor.visitSingularEnumField(value: self.loadedComputeUnits, fieldNumber: 4)
|
||||||
|
}
|
||||||
|
if !self.supportedComputeUnits.isEmpty {
|
||||||
|
try visitor.visitPackedEnumField(value: self.supportedComputeUnits, fieldNumber: 5)
|
||||||
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,6 +502,8 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
|||||||
if lhs.name != rhs.name {return false}
|
if lhs.name != rhs.name {return false}
|
||||||
if lhs.attention != rhs.attention {return false}
|
if lhs.attention != rhs.attention {return false}
|
||||||
if lhs.isLoaded != rhs.isLoaded {return false}
|
if lhs.isLoaded != rhs.isLoaded {return false}
|
||||||
|
if lhs.loadedComputeUnits != rhs.loadedComputeUnits {return false}
|
||||||
|
if lhs.supportedComputeUnits != rhs.supportedComputeUnits {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -316,7 +512,8 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
|||||||
extension SdImage: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
extension SdImage: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||||
public static let protoMessageName: String = _protobuf_package + ".Image"
|
public static let protoMessageName: String = _protobuf_package + ".Image"
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .same(proto: "content"),
|
1: .same(proto: "format"),
|
||||||
|
2: .same(proto: "data"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -325,21 +522,26 @@ extension SdImage: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBa
|
|||||||
// allocates stack space for every case branch when no optimizations are
|
// allocates stack space for every case branch when no optimizations are
|
||||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeSingularBytesField(value: &self.content) }()
|
case 1: try { try decoder.decodeSingularEnumField(value: &self.format) }()
|
||||||
|
case 2: try { try decoder.decodeSingularBytesField(value: &self.data) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||||
if !self.content.isEmpty {
|
if self.format != .png {
|
||||||
try visitor.visitSingularBytesField(value: self.content, fieldNumber: 1)
|
try visitor.visitSingularEnumField(value: self.format, fieldNumber: 1)
|
||||||
|
}
|
||||||
|
if !self.data.isEmpty {
|
||||||
|
try visitor.visitSingularBytesField(value: self.data, fieldNumber: 2)
|
||||||
}
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func ==(lhs: SdImage, rhs: SdImage) -> Bool {
|
public static func ==(lhs: SdImage, rhs: SdImage) -> Bool {
|
||||||
if lhs.content != rhs.content {return false}
|
if lhs.format != rhs.format {return false}
|
||||||
|
if lhs.data != rhs.data {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -367,7 +569,7 @@ extension SdListModelsRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImpl
|
|||||||
extension SdListModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
extension SdListModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||||
public static let protoMessageName: String = _protobuf_package + ".ListModelsResponse"
|
public static let protoMessageName: String = _protobuf_package + ".ListModelsResponse"
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .same(proto: "models"),
|
1: .standard(proto: "available_models"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -376,59 +578,21 @@ extension SdListModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImp
|
|||||||
// allocates stack space for every case branch when no optimizations are
|
// allocates stack space for every case branch when no optimizations are
|
||||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeRepeatedMessageField(value: &self.models) }()
|
case 1: try { try decoder.decodeRepeatedMessageField(value: &self.availableModels) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||||
if !self.models.isEmpty {
|
if !self.availableModels.isEmpty {
|
||||||
try visitor.visitRepeatedMessageField(value: self.models, fieldNumber: 1)
|
try visitor.visitRepeatedMessageField(value: self.availableModels, fieldNumber: 1)
|
||||||
}
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func ==(lhs: SdListModelsResponse, rhs: SdListModelsResponse) -> Bool {
|
public static func ==(lhs: SdListModelsResponse, rhs: SdListModelsResponse) -> Bool {
|
||||||
if lhs.models != rhs.models {return false}
|
if lhs.availableModels != rhs.availableModels {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
extension SdReloadModelsRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
|
||||||
public static let protoMessageName: String = _protobuf_package + ".ReloadModelsRequest"
|
|
||||||
public static let _protobuf_nameMap = SwiftProtobuf._NameMap()
|
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
|
||||||
while let _ = try decoder.nextFieldNumber() {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
|
||||||
}
|
|
||||||
|
|
||||||
public static func ==(lhs: SdReloadModelsRequest, rhs: SdReloadModelsRequest) -> Bool {
|
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
extension SdReloadModelsResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
|
||||||
public static let protoMessageName: String = _protobuf_package + ".ReloadModelsResponse"
|
|
||||||
public static let _protobuf_nameMap = SwiftProtobuf._NameMap()
|
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
|
||||||
while let _ = try decoder.nextFieldNumber() {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
|
||||||
}
|
|
||||||
|
|
||||||
public static func ==(lhs: SdReloadModelsResponse, rhs: SdReloadModelsResponse) -> Bool {
|
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -439,7 +603,6 @@ extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImple
|
|||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .standard(proto: "model_name"),
|
1: .standard(proto: "model_name"),
|
||||||
2: .standard(proto: "compute_units"),
|
2: .standard(proto: "compute_units"),
|
||||||
3: .standard(proto: "reduce_memory"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -450,7 +613,6 @@ extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImple
|
|||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
|
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
|
||||||
case 2: try { try decoder.decodeSingularEnumField(value: &self.computeUnits) }()
|
case 2: try { try decoder.decodeSingularEnumField(value: &self.computeUnits) }()
|
||||||
case 3: try { try decoder.decodeSingularBoolField(value: &self.reduceMemory) }()
|
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -463,16 +625,12 @@ extension SdLoadModelRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImple
|
|||||||
if self.computeUnits != .cpu {
|
if self.computeUnits != .cpu {
|
||||||
try visitor.visitSingularEnumField(value: self.computeUnits, fieldNumber: 2)
|
try visitor.visitSingularEnumField(value: self.computeUnits, fieldNumber: 2)
|
||||||
}
|
}
|
||||||
if self.reduceMemory != false {
|
|
||||||
try visitor.visitSingularBoolField(value: self.reduceMemory, fieldNumber: 3)
|
|
||||||
}
|
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func ==(lhs: SdLoadModelRequest, rhs: SdLoadModelRequest) -> Bool {
|
public static func ==(lhs: SdLoadModelRequest, rhs: SdLoadModelRequest) -> Bool {
|
||||||
if lhs.modelName != rhs.modelName {return false}
|
if lhs.modelName != rhs.modelName {return false}
|
||||||
if lhs.computeUnits != rhs.computeUnits {return false}
|
if lhs.computeUnits != rhs.computeUnits {return false}
|
||||||
if lhs.reduceMemory != rhs.reduceMemory {return false}
|
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -501,9 +659,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesRequest"
|
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesRequest"
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .standard(proto: "model_name"),
|
1: .standard(proto: "model_name"),
|
||||||
2: .standard(proto: "image_count"),
|
2: .standard(proto: "output_image_format"),
|
||||||
3: .same(proto: "prompt"),
|
3: .standard(proto: "batch_count"),
|
||||||
4: .standard(proto: "negative_prompt"),
|
4: .standard(proto: "batch_size"),
|
||||||
|
5: .same(proto: "prompt"),
|
||||||
|
6: .standard(proto: "negative_prompt"),
|
||||||
|
7: .same(proto: "seed"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -513,9 +674,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
|
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
|
||||||
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.imageCount) }()
|
case 2: try { try decoder.decodeSingularEnumField(value: &self.outputImageFormat) }()
|
||||||
case 3: try { try decoder.decodeSingularStringField(value: &self.prompt) }()
|
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.batchCount) }()
|
||||||
case 4: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }()
|
case 4: try { try decoder.decodeSingularUInt32Field(value: &self.batchSize) }()
|
||||||
|
case 5: try { try decoder.decodeSingularStringField(value: &self.prompt) }()
|
||||||
|
case 6: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }()
|
||||||
|
case 7: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -525,23 +689,35 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
|||||||
if !self.modelName.isEmpty {
|
if !self.modelName.isEmpty {
|
||||||
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
|
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
|
||||||
}
|
}
|
||||||
if self.imageCount != 0 {
|
if self.outputImageFormat != .png {
|
||||||
try visitor.visitSingularUInt32Field(value: self.imageCount, fieldNumber: 2)
|
try visitor.visitSingularEnumField(value: self.outputImageFormat, fieldNumber: 2)
|
||||||
|
}
|
||||||
|
if self.batchCount != 0 {
|
||||||
|
try visitor.visitSingularUInt32Field(value: self.batchCount, fieldNumber: 3)
|
||||||
|
}
|
||||||
|
if self.batchSize != 0 {
|
||||||
|
try visitor.visitSingularUInt32Field(value: self.batchSize, fieldNumber: 4)
|
||||||
}
|
}
|
||||||
if !self.prompt.isEmpty {
|
if !self.prompt.isEmpty {
|
||||||
try visitor.visitSingularStringField(value: self.prompt, fieldNumber: 3)
|
try visitor.visitSingularStringField(value: self.prompt, fieldNumber: 5)
|
||||||
}
|
}
|
||||||
if !self.negativePrompt.isEmpty {
|
if !self.negativePrompt.isEmpty {
|
||||||
try visitor.visitSingularStringField(value: self.negativePrompt, fieldNumber: 4)
|
try visitor.visitSingularStringField(value: self.negativePrompt, fieldNumber: 6)
|
||||||
|
}
|
||||||
|
if self.seed != 0 {
|
||||||
|
try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 7)
|
||||||
}
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func ==(lhs: SdGenerateImagesRequest, rhs: SdGenerateImagesRequest) -> Bool {
|
public static func ==(lhs: SdGenerateImagesRequest, rhs: SdGenerateImagesRequest) -> Bool {
|
||||||
if lhs.modelName != rhs.modelName {return false}
|
if lhs.modelName != rhs.modelName {return false}
|
||||||
if lhs.imageCount != rhs.imageCount {return false}
|
if lhs.outputImageFormat != rhs.outputImageFormat {return false}
|
||||||
|
if lhs.batchCount != rhs.batchCount {return false}
|
||||||
|
if lhs.batchSize != rhs.batchSize {return false}
|
||||||
if lhs.prompt != rhs.prompt {return false}
|
if lhs.prompt != rhs.prompt {return false}
|
||||||
if lhs.negativePrompt != rhs.negativePrompt {return false}
|
if lhs.negativePrompt != rhs.negativePrompt {return false}
|
||||||
|
if lhs.seed != rhs.seed {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -551,6 +727,7 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag
|
|||||||
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesResponse"
|
public static let protoMessageName: String = _protobuf_package + ".GenerateImagesResponse"
|
||||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||||
1: .same(proto: "images"),
|
1: .same(proto: "images"),
|
||||||
|
2: .same(proto: "seeds"),
|
||||||
]
|
]
|
||||||
|
|
||||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||||
@ -560,6 +737,7 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag
|
|||||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||||
switch fieldNumber {
|
switch fieldNumber {
|
||||||
case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
|
case 1: try { try decoder.decodeRepeatedMessageField(value: &self.images) }()
|
||||||
|
case 2: try { try decoder.decodeRepeatedUInt32Field(value: &self.seeds) }()
|
||||||
default: break
|
default: break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -569,11 +747,15 @@ extension SdGenerateImagesResponse: SwiftProtobuf.Message, SwiftProtobuf._Messag
|
|||||||
if !self.images.isEmpty {
|
if !self.images.isEmpty {
|
||||||
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1)
|
try visitor.visitRepeatedMessageField(value: self.images, fieldNumber: 1)
|
||||||
}
|
}
|
||||||
|
if !self.seeds.isEmpty {
|
||||||
|
try visitor.visitPackedUInt32Field(value: self.seeds, fieldNumber: 2)
|
||||||
|
}
|
||||||
try unknownFields.traverse(visitor: &visitor)
|
try unknownFields.traverse(visitor: &visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func ==(lhs: SdGenerateImagesResponse, rhs: SdGenerateImagesResponse) -> Bool {
|
public static func ==(lhs: SdGenerateImagesResponse, rhs: SdGenerateImagesResponse) -> Bool {
|
||||||
if lhs.images != rhs.images {return false}
|
if lhs.images != rhs.images {return false}
|
||||||
|
if lhs.seeds != rhs.seeds {return false}
|
||||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
26
Sources/StableDiffusionProtos/Utilities.swift
Normal file
26
Sources/StableDiffusionProtos/Utilities.swift
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import CoreML
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
public extension SdComputeUnits {
|
||||||
|
func toMlComputeUnits() -> MLComputeUnits {
|
||||||
|
switch self {
|
||||||
|
case .all: return .all
|
||||||
|
case .cpu: return .cpuOnly
|
||||||
|
case .cpuAndGpu: return .cpuAndGPU
|
||||||
|
case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
|
||||||
|
default: return .all
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public extension MLComputeUnits {
|
||||||
|
func toSdComputeUnits() -> SdComputeUnits {
|
||||||
|
switch self {
|
||||||
|
case .all: return .all
|
||||||
|
case .cpuOnly: return .cpu
|
||||||
|
case .cpuAndGPU: return .cpuAndGpu
|
||||||
|
case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
|
||||||
|
default: return .all
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
|
|||||||
self.modelManager = modelManager
|
self.modelManager = modelManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateImage(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
|
func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
|
||||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||||
throw SdCoreError.modelNotFound
|
throw SdCoreError.modelNotFound
|
||||||
}
|
}
|
||||||
|
@ -11,20 +11,15 @@ class ModelServiceProvider: SdModelServiceAsyncProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func listModels(request _: SdListModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse {
|
func listModels(request _: SdListModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdListModelsResponse {
|
||||||
let models = await modelManager.listModels()
|
let models = try await modelManager.listAvailableModels()
|
||||||
var response = SdListModelsResponse()
|
var response = SdListModelsResponse()
|
||||||
response.models.append(contentsOf: models)
|
response.availableModels.append(contentsOf: models)
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadModels(request _: SdReloadModelsRequest, context _: GRPCAsyncServerCallContext) async throws -> SdReloadModelsResponse {
|
|
||||||
try await modelManager.reloadModels()
|
|
||||||
return SdReloadModelsResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
|
func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
|
||||||
let state = try await modelManager.createModelState(name: request.modelName)
|
let state = try await modelManager.createModelState(name: request.modelName)
|
||||||
try await state.load()
|
try await state.load(request: request)
|
||||||
return SdLoadModelResponse()
|
return SdLoadModelResponse()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ struct ServerCommand: ParsableCommand {
|
|||||||
let semaphore = DispatchSemaphore(value: 0)
|
let semaphore = DispatchSemaphore(value: 0)
|
||||||
Task {
|
Task {
|
||||||
do {
|
do {
|
||||||
try await modelManager.reloadModels()
|
try await modelManager.reloadAvailableModels()
|
||||||
} catch {
|
} catch {
|
||||||
ServerCommand.exit(withError: error)
|
ServerCommand.exit(withError: error)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user