mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-03 05:30:54 +00:00
Support for starting images and many more parameters.
This commit is contained in:
@ -1,12 +1,16 @@
|
||||
package gay.pizza.stable.diffusion.sample
|
||||
|
||||
import com.google.protobuf.ByteString
|
||||
import gay.pizza.stable.diffusion.StableDiffusion
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.Image
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
|
||||
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
|
||||
import io.grpc.ManagedChannelBuilder
|
||||
import kotlin.io.path.Path
|
||||
import kotlin.io.path.exists
|
||||
import kotlin.io.path.readBytes
|
||||
import kotlin.io.path.writeBytes
|
||||
import kotlin.system.exitProcess
|
||||
|
||||
@ -41,6 +45,8 @@ fun main() {
|
||||
|
||||
println("generating images...")
|
||||
|
||||
val startingImagePath = Path("work/start.png")
|
||||
|
||||
val request = GenerateImagesRequest.newBuilder().apply {
|
||||
modelName = model.name
|
||||
outputImageFormat = StableDiffusion.ImageFormat.png
|
||||
@ -48,6 +54,14 @@ fun main() {
|
||||
batchCount = 2
|
||||
prompt = "cat"
|
||||
negativePrompt = "bad, low quality, nsfw"
|
||||
if (startingImagePath.exists()) {
|
||||
val image = Image.newBuilder().apply {
|
||||
format = StableDiffusion.ImageFormat.png
|
||||
data = ByteString.copyFrom(startingImagePath.readBytes())
|
||||
}.build()
|
||||
|
||||
startingImage = image
|
||||
}
|
||||
}.build()
|
||||
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)
|
||||
|
||||
|
@ -219,6 +219,40 @@ message GenerateImagesRequest {
|
||||
* Zero indicates that the seed should be random.
|
||||
*/
|
||||
uint32 seed = 7;
|
||||
|
||||
/**
|
||||
* An optional starting image to use for generation.
|
||||
*/
|
||||
Image starting_image = 8;
|
||||
|
||||
/**
|
||||
* Indicates whether to enable the safety check network, if it is available.
|
||||
*/
|
||||
bool enable_safety_check = 9;
|
||||
|
||||
/**
|
||||
* The scheduler to use for generation.
|
||||
* The default is PNDM, if not specified.
|
||||
*/
|
||||
Scheduler scheduler = 10;
|
||||
|
||||
/**
|
||||
* The guidance scale, which controls the influence the prompt has on the image.
|
||||
* If not specified, a reasonable default value is used.
|
||||
*/
|
||||
float guidance_scale = 11;
|
||||
|
||||
/**
|
||||
* The strength of the image generation.
|
||||
* If not specified, a reasonable default value is used.
|
||||
*/
|
||||
float strength = 12;
|
||||
|
||||
/**
|
||||
* The number of inference steps to perform.
|
||||
* If not specified, a reasonable default value is used.
|
||||
*/
|
||||
uint32 step_count = 13;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3,5 +3,6 @@ import Foundation
|
||||
/// Errors thrown by the Stable Diffusion core library.
public enum SdCoreError: Error {
  /// NOTE(review): presumably raised when a model is used before it has been loaded —
  /// the throwing site is not visible here; confirm against callers.
  case modelNotLoaded

  /// NOTE(review): presumably raised when an image cannot be encoded —
  /// the throwing site is not visible here; confirm against callers.
  case imageEncodeFailed

  /// An encoded image could not be turned into a decoded image
  /// (for example an unsupported format or undecodable data).
  case imageDecodeFailed

  /// No model matching the requested name could be found.
  case modelNotFound
}
|
||||
|
@ -36,3 +36,20 @@ extension CGImage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public extension SdImage {
  /// Decodes this image's encoded bytes into a `CGImage`.
  ///
  /// Only images whose `format` is `.png` are handled; every other format is rejected.
  ///
  /// - Returns: The decoded image.
  /// - Throws: `SdCoreError.imageDecodeFailed` if a data provider cannot be created
  ///   from `data`, if `format` is not `.png`, or if the PNG initializer returns `nil`.
  func toCgImage() throws -> CGImage {
    guard let provider = CGDataProvider(data: data as CFData) else {
      throw SdCoreError.imageDecodeFailed
    }

    // PNG is the only format with a decoding path here.
    guard format == .png else {
      throw SdCoreError.imageDecodeFailed
    }

    let decoded = CGImage(
      pngDataProviderSource: provider,
      decode: nil,
      shouldInterpolate: false,
      intent: .defaultIntent
    )
    guard let decoded else {
      throw SdCoreError.imageDecodeFailed
    }
    return decoded
  }
}
|
||||
|
@ -48,6 +48,31 @@ public actor ModelState {
|
||||
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
|
||||
pipelineConfig.negativePrompt = request.negativePrompt
|
||||
pipelineConfig.imageCount = Int(request.batchSize)
|
||||
|
||||
if request.hasStartingImage {
|
||||
pipelineConfig.startingImage = try request.startingImage.toCgImage()
|
||||
}
|
||||
|
||||
if request.guidanceScale != 0.0 {
|
||||
pipelineConfig.guidanceScale = request.guidanceScale
|
||||
}
|
||||
|
||||
if request.stepCount != 0 {
|
||||
pipelineConfig.stepCount = Int(request.stepCount)
|
||||
}
|
||||
|
||||
if request.strength != 0.0 {
|
||||
pipelineConfig.strength = request.strength
|
||||
}
|
||||
|
||||
pipelineConfig.disableSafety = !request.enableSafetyCheck
|
||||
|
||||
switch request.scheduler {
|
||||
case .pndm: pipelineConfig.schedulerType = .pndmScheduler
|
||||
case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
|
||||
default: pipelineConfig.schedulerType = .pndmScheduler
|
||||
}
|
||||
|
||||
var response = SdGenerateImagesResponse()
|
||||
for _ in 0 ..< request.batchCount {
|
||||
var seed = baseSeed
|
||||
|
@ -380,9 +380,46 @@ public struct SdGenerateImagesRequest {
|
||||
/// Zero indicates that the seed should be random.
|
||||
public var seed: UInt32 = 0
|
||||
|
||||
///*
|
||||
/// An optional starting image to use for generation.
|
||||
public var startingImage: SdImage {
|
||||
get {return _startingImage ?? SdImage()}
|
||||
set {_startingImage = newValue}
|
||||
}
|
||||
/// Returns true if `startingImage` has been explicitly set.
|
||||
public var hasStartingImage: Bool {return self._startingImage != nil}
|
||||
/// Clears the value of `startingImage`. Subsequent reads from it will return its default value.
|
||||
public mutating func clearStartingImage() {self._startingImage = nil}
|
||||
|
||||
///*
|
||||
/// Indicates whether to enable the safety check network, if it is available.
|
||||
public var enableSafetyCheck: Bool = false
|
||||
|
||||
///*
|
||||
/// The scheduler to use for generation.
|
||||
/// The default is PNDM, if not specified.
|
||||
public var scheduler: SdScheduler = .pndm
|
||||
|
||||
///*
|
||||
/// The guidance scale, which controls the influence the prompt has on the image.
|
||||
/// If not specified, a reasonable default value is used.
|
||||
public var guidanceScale: Float = 0
|
||||
|
||||
///*
|
||||
/// The strength of the image generation.
|
||||
/// If not specified, a reasonable default value is used.
|
||||
public var strength: Float = 0
|
||||
|
||||
///*
|
||||
/// The number of inference steps to perform.
|
||||
/// If not specified, a reasonable default value is used.
|
||||
public var stepCount: UInt32 = 0
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
|
||||
fileprivate var _startingImage: SdImage? = nil
|
||||
}
|
||||
|
||||
///*
|
||||
@ -665,6 +702,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
5: .same(proto: "prompt"),
|
||||
6: .standard(proto: "negative_prompt"),
|
||||
7: .same(proto: "seed"),
|
||||
8: .standard(proto: "starting_image"),
|
||||
9: .standard(proto: "enable_safety_check"),
|
||||
10: .same(proto: "scheduler"),
|
||||
11: .standard(proto: "guidance_scale"),
|
||||
12: .same(proto: "strength"),
|
||||
13: .standard(proto: "step_count"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
@ -680,12 +723,22 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
case 5: try { try decoder.decodeSingularStringField(value: &self.prompt) }()
|
||||
case 6: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }()
|
||||
case 7: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }()
|
||||
case 8: try { try decoder.decodeSingularMessageField(value: &self._startingImage) }()
|
||||
case 9: try { try decoder.decodeSingularBoolField(value: &self.enableSafetyCheck) }()
|
||||
case 10: try { try decoder.decodeSingularEnumField(value: &self.scheduler) }()
|
||||
case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }()
|
||||
case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }()
|
||||
case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every if/case branch local when no optimizations
|
||||
// are enabled. https://github.com/apple/swift-protobuf/issues/1034 and
|
||||
// https://github.com/apple/swift-protobuf/issues/1182
|
||||
if !self.modelName.isEmpty {
|
||||
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
|
||||
}
|
||||
@ -707,6 +760,24 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
if self.seed != 0 {
|
||||
try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 7)
|
||||
}
|
||||
try { if let v = self._startingImage {
|
||||
try visitor.visitSingularMessageField(value: v, fieldNumber: 8)
|
||||
} }()
|
||||
if self.enableSafetyCheck != false {
|
||||
try visitor.visitSingularBoolField(value: self.enableSafetyCheck, fieldNumber: 9)
|
||||
}
|
||||
if self.scheduler != .pndm {
|
||||
try visitor.visitSingularEnumField(value: self.scheduler, fieldNumber: 10)
|
||||
}
|
||||
if self.guidanceScale != 0 {
|
||||
try visitor.visitSingularFloatField(value: self.guidanceScale, fieldNumber: 11)
|
||||
}
|
||||
if self.strength != 0 {
|
||||
try visitor.visitSingularFloatField(value: self.strength, fieldNumber: 12)
|
||||
}
|
||||
if self.stepCount != 0 {
|
||||
try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
@ -718,6 +789,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
|
||||
if lhs.prompt != rhs.prompt {return false}
|
||||
if lhs.negativePrompt != rhs.negativePrompt {return false}
|
||||
if lhs.seed != rhs.seed {return false}
|
||||
if lhs._startingImage != rhs._startingImage {return false}
|
||||
if lhs.enableSafetyCheck != rhs.enableSafetyCheck {return false}
|
||||
if lhs.scheduler != rhs.scheduler {return false}
|
||||
if lhs.guidanceScale != rhs.guidanceScale {return false}
|
||||
if lhs.strength != rhs.strength {return false}
|
||||
if lhs.stepCount != rhs.stepCount {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
|
@ -11,9 +11,14 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
|
||||
}
|
||||
|
||||
/// Handles an image generation request by delegating to the state of the requested model.
///
/// - Parameters:
///   - request: The generation request; `request.modelName` selects the model state.
///   - context: The gRPC call context (unused).
/// - Returns: The response returned by the model state's `generate(_:)`.
/// - Throws: `SdCoreError.modelNotFound` when the model manager returns no state for
///   `request.modelName`; otherwise rethrows whatever `generate(_:)` throws.
func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
  do {
    guard let state = await modelManager.getModelState(name: request.modelName) else {
      throw SdCoreError.modelNotFound
    }
    return try await state.generate(request)
  } catch {
    // Print the error, then rethrow it unchanged to the caller.
    print(error)
    throw error
  }
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user