diff --git a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt index 844d8c5..aa44a07 100644 --- a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt +++ b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt @@ -1,12 +1,16 @@ package gay.pizza.stable.diffusion.sample +import com.google.protobuf.ByteString import gay.pizza.stable.diffusion.StableDiffusion import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest +import gay.pizza.stable.diffusion.StableDiffusion.Image import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest import gay.pizza.stable.diffusion.StableDiffusionRpcClient import io.grpc.ManagedChannelBuilder import kotlin.io.path.Path +import kotlin.io.path.exists +import kotlin.io.path.readBytes import kotlin.io.path.writeBytes import kotlin.system.exitProcess @@ -41,6 +45,8 @@ fun main() { println("generating images...") + val startingImagePath = Path("work/start.png") + val request = GenerateImagesRequest.newBuilder().apply { modelName = model.name outputImageFormat = StableDiffusion.ImageFormat.png @@ -48,6 +54,14 @@ fun main() { batchCount = 2 prompt = "cat" negativePrompt = "bad, low quality, nsfw" + if (startingImagePath.exists()) { + val image = Image.newBuilder().apply { + format = StableDiffusion.ImageFormat.png + data = ByteString.copyFrom(startingImagePath.readBytes()) + }.build() + + startingImage = image + } }.build() val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request) diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index 0b0d10b..50bdd25 100644 --- a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -219,6 +219,40 @@ message GenerateImagesRequest { * Zero indicates that the seed should be random. */ uint32 seed = 7; + + /** + * An optional starting image to use for generation. + */ + Image starting_image = 8; + + /** + * Indicates whether to enable the safety check network, if it is available. + */ + bool enable_safety_check = 9; + + /** + * The scheduler to use for generation. + * The default is PNDM, if not specified. + */ + Scheduler scheduler = 10; + + /** + * The guidance scale, which controls the influence the prompt has on the image. + * If not specified, a reasonable default value is used. + */ + float guidance_scale = 11; + + /** + * The strength of the image generation. + * If not specified, a reasonable default value is used. + */ + float strength = 12; + + /** + * The number of inference steps to perform. + * If not specified, a reasonable default value is used. + */ + uint32 step_count = 13; } /** diff --git a/Sources/StableDiffusionCore/Errors.swift b/Sources/StableDiffusionCore/Errors.swift index 6043444..4af0f8c 100644 --- a/Sources/StableDiffusionCore/Errors.swift +++ b/Sources/StableDiffusionCore/Errors.swift @@ -3,5 +3,6 @@ import Foundation public enum SdCoreError: Error { case modelNotLoaded case imageEncodeFailed + case imageDecodeFailed case modelNotFound } diff --git a/Sources/StableDiffusionCore/ImageExtensions.swift b/Sources/StableDiffusionCore/ImageExtensions.swift index fa70e0d..c079b0e 100644 --- a/Sources/StableDiffusionCore/ImageExtensions.swift +++ b/Sources/StableDiffusionCore/ImageExtensions.swift @@ -36,3 +36,20 @@ extension CGImage { } } } + +public extension SdImage { + func toCgImage() throws -> CGImage { + guard let dataProvider = CGDataProvider(data: data as CFData) else { + throw SdCoreError.imageDecodeFailed + } + + if format == .png { + guard let image = CGImage(pngDataProviderSource: dataProvider, decode: nil, shouldInterpolate: false, intent: .defaultIntent) else { + throw SdCoreError.imageDecodeFailed + } + return image + } else { + throw SdCoreError.imageDecodeFailed + } + } +} diff --git a/Sources/StableDiffusionCore/ModelState.swift b/Sources/StableDiffusionCore/ModelState.swift index ba27bcf..866e319 100644 --- a/Sources/StableDiffusionCore/ModelState.swift +++ b/Sources/StableDiffusionCore/ModelState.swift @@ -48,6 +48,31 @@ public actor ModelState { var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt) pipelineConfig.negativePrompt = request.negativePrompt pipelineConfig.imageCount = Int(request.batchSize) + + if request.hasStartingImage { + pipelineConfig.startingImage = try request.startingImage.toCgImage() + } + + if request.guidanceScale != 0.0 { + pipelineConfig.guidanceScale = request.guidanceScale + } + + if request.stepCount != 0 { + pipelineConfig.stepCount = Int(request.stepCount) + } + + if request.strength != 0.0 { + pipelineConfig.strength = request.strength + } + + pipelineConfig.disableSafety = !request.enableSafetyCheck + + switch request.scheduler { + case .pndm: pipelineConfig.schedulerType = .pndmScheduler + case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler + default: pipelineConfig.schedulerType = .pndmScheduler + } + var response = SdGenerateImagesResponse() for _ in 0 ..< request.batchCount { var seed = baseSeed diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift index 25ea181..203c56c 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -380,9 +380,46 @@ public struct SdGenerateImagesRequest { /// Zero indicates that the seed should be random. public var seed: UInt32 = 0 + ///* + /// An optional starting image to use for generation. + public var startingImage: SdImage { + get {return _startingImage ?? SdImage()} + set {_startingImage = newValue} + } + /// Returns true if `startingImage` has been explicitly set. + public var hasStartingImage: Bool {return self._startingImage != nil} + /// Clears the value of `startingImage`. Subsequent reads from it will return its default value. + public mutating func clearStartingImage() {self._startingImage = nil} + + ///* + /// Indicates whether to enable the safety check network, if it is available. + public var enableSafetyCheck: Bool = false + + ///* + /// The scheduler to use for generation. + /// The default is PNDM, if not specified. + public var scheduler: SdScheduler = .pndm + + ///* + /// The guidance scale, which controls the influence the prompt has on the image. + /// If not specified, a reasonable default value is used. + public var guidanceScale: Float = 0 + + ///* + /// The strength of the image generation. + /// If not specified, a reasonable default value is used. + public var strength: Float = 0 + + ///* + /// The number of inference steps to perform. + /// If not specified, a reasonable default value is used. + public var stepCount: UInt32 = 0 + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} + + fileprivate var _startingImage: SdImage? = nil } ///* @@ -665,6 +702,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message 5: .same(proto: "prompt"), 6: .standard(proto: "negative_prompt"), 7: .same(proto: "seed"), + 8: .standard(proto: "starting_image"), + 9: .standard(proto: "enable_safety_check"), + 10: .same(proto: "scheduler"), + 11: .standard(proto: "guidance_scale"), + 12: .same(proto: "strength"), + 13: .standard(proto: "step_count"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -680,12 +723,22 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message case 5: try { try decoder.decodeSingularStringField(value: &self.prompt) }() case 6: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }() case 7: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }() + case 8: try { try decoder.decodeSingularMessageField(value: &self._startingImage) }() + case 9: try { try decoder.decodeSingularBoolField(value: &self.enableSafetyCheck) }() + case 10: try { try decoder.decodeSingularEnumField(value: &self.scheduler) }() + case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }() + case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }() + case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }() default: break } } } public func traverse(visitor: inout V) throws { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every if/case branch local when no optimizations + // are enabled. https://github.com/apple/swift-protobuf/issues/1034 and + // https://github.com/apple/swift-protobuf/issues/1182 if !self.modelName.isEmpty { try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1) } @@ -707,6 +760,24 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message if self.seed != 0 { try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 7) } + try { if let v = self._startingImage { + try visitor.visitSingularMessageField(value: v, fieldNumber: 8) + } }() + if self.enableSafetyCheck != false { + try visitor.visitSingularBoolField(value: self.enableSafetyCheck, fieldNumber: 9) + } + if self.scheduler != .pndm { + try visitor.visitSingularEnumField(value: self.scheduler, fieldNumber: 10) + } + if self.guidanceScale != 0 { + try visitor.visitSingularFloatField(value: self.guidanceScale, fieldNumber: 11) + } + if self.strength != 0 { + try visitor.visitSingularFloatField(value: self.strength, fieldNumber: 12) + } + if self.stepCount != 0 { + try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13) + } try unknownFields.traverse(visitor: &visitor) } @@ -718,6 +789,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message if lhs.prompt != rhs.prompt {return false} if lhs.negativePrompt != rhs.negativePrompt {return false} if lhs.seed != rhs.seed {return false} + if lhs._startingImage != rhs._startingImage {return false} + if lhs.enableSafetyCheck != rhs.enableSafetyCheck {return false} + if lhs.scheduler != rhs.scheduler {return false} + if lhs.guidanceScale != rhs.guidanceScale {return false} + if lhs.strength != rhs.strength {return false} + if lhs.stepCount != rhs.stepCount {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } diff --git a/Sources/StableDiffusionServer/ImageGenerationService.swift b/Sources/StableDiffusionServer/ImageGenerationService.swift index 993922a..b06eddd 100644 --- a/Sources/StableDiffusionServer/ImageGenerationService.swift +++ b/Sources/StableDiffusionServer/ImageGenerationService.swift @@ -11,9 +11,14 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider { } func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse { - guard let state = await modelManager.getModelState(name: request.modelName) else { - throw SdCoreError.modelNotFound + do { + guard let state = await modelManager.getModelState(name: request.modelName) else { + throw SdCoreError.modelNotFound + } + return try await state.generate(request) + } catch { + print(error) + throw error } - return try await state.generate(request) } }