Support for starting images and many more parameters.

This commit is contained in:
2023-04-23 02:40:41 -07:00
parent 7c0b2779f4
commit d31e80bf4c
7 changed files with 176 additions and 3 deletions

View File

@ -1,12 +1,16 @@
package gay.pizza.stable.diffusion.sample
import com.google.protobuf.ByteString
import gay.pizza.stable.diffusion.StableDiffusion
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.Image
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlin.io.path.Path
import kotlin.io.path.exists
import kotlin.io.path.readBytes
import kotlin.io.path.writeBytes
import kotlin.system.exitProcess
@ -41,6 +45,8 @@ fun main() {
println("generating images...")
val startingImagePath = Path("work/start.png")
val request = GenerateImagesRequest.newBuilder().apply {
modelName = model.name
outputImageFormat = StableDiffusion.ImageFormat.png
@ -48,6 +54,14 @@ fun main() {
batchCount = 2
prompt = "cat"
negativePrompt = "bad, low quality, nsfw"
if (startingImagePath.exists()) {
val image = Image.newBuilder().apply {
format = StableDiffusion.ImageFormat.png
data = ByteString.copyFrom(startingImagePath.readBytes())
}.build()
startingImage = image
}
}.build()
val generateImagesResponse = client.imageGenerationServiceBlocking.generateImages(request)

View File

@ -219,6 +219,40 @@ message GenerateImagesRequest {
* Zero indicates that the seed should be random.
*/
uint32 seed = 7;
/**
* An optional starting image to use for generation.
*/
Image starting_image = 8;
/**
* Indicates whether to enable the safety check network, if it is available.
*/
bool enable_safety_check = 9;
/**
* The scheduler to use for generation.
* The default is PNDM, if not specified.
*/
Scheduler scheduler = 10;
/**
* The guidance scale, which controls the influence the prompt has on the image.
* If not specified, a reasonable default value is used.
*/
float guidance_scale = 11;
/**
* The strength of the image generation.
* If not specified, a reasonable default value is used.
*/
float strength = 12;
/**
* The number of inference steps to perform.
* If not specified, a reasonable default value is used.
*/
uint32 step_count = 13;
}
/**

View File

@ -3,5 +3,6 @@ import Foundation
/// Error values thrown by this project's Swift code.
public enum SdCoreError: Error {
// NOTE(review): no throw site for this case is visible here; presumably raised
// when generation is requested before a model has been loaded — confirm.
case modelNotLoaded
// NOTE(review): no throw site for this case is visible here; presumably the
// encoding counterpart of `imageDecodeFailed` — confirm.
case imageEncodeFailed
/// Thrown by `SdImage.toCgImage()` when a data provider cannot be created from
/// the image bytes, when the image format is not PNG, or when the PNG data
/// cannot be turned into a `CGImage`.
case imageDecodeFailed
/// Thrown by the image generation service when no model state exists for the
/// model name given in the request.
case modelNotFound
}

View File

@ -36,3 +36,20 @@ extension CGImage {
}
}
}
public extension SdImage {
  /// Decodes this image's bytes into a `CGImage`.
  ///
  /// Only the PNG format is handled; every other `format` value is rejected.
  ///
  /// - Returns: The image decoded from `data`.
  /// - Throws: `SdCoreError.imageDecodeFailed` if a data provider cannot be
  ///   created from `data`, if `format` is not `.png`, or if CoreGraphics
  ///   cannot build an image from the PNG data.
  func toCgImage() throws -> CGImage {
    // The conditions are evaluated in order, so the provider is created before
    // the format is inspected and the PNG decode only runs for PNG input.
    guard
      let source = CGDataProvider(data: data as CFData),
      format == .png,
      let decoded = CGImage(
        pngDataProviderSource: source,
        decode: nil,
        shouldInterpolate: false,
        intent: .defaultIntent
      )
    else {
      throw SdCoreError.imageDecodeFailed
    }
    return decoded
  }
}

View File

@ -48,6 +48,31 @@ public actor ModelState {
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: request.prompt)
pipelineConfig.negativePrompt = request.negativePrompt
pipelineConfig.imageCount = Int(request.batchSize)
if request.hasStartingImage {
pipelineConfig.startingImage = try request.startingImage.toCgImage()
}
if request.guidanceScale != 0.0 {
pipelineConfig.guidanceScale = request.guidanceScale
}
if request.stepCount != 0 {
pipelineConfig.stepCount = Int(request.stepCount)
}
if request.strength != 0.0 {
pipelineConfig.strength = request.strength
}
pipelineConfig.disableSafety = !request.enableSafetyCheck
switch request.scheduler {
case .pndm: pipelineConfig.schedulerType = .pndmScheduler
case .dpmSolverPlusPlus: pipelineConfig.schedulerType = .dpmSolverMultistepScheduler
default: pipelineConfig.schedulerType = .pndmScheduler
}
var response = SdGenerateImagesResponse()
for _ in 0 ..< request.batchCount {
var seed = baseSeed

View File

@ -380,9 +380,46 @@ public struct SdGenerateImagesRequest {
/// Zero indicates that the seed should be random.
public var seed: UInt32 = 0
/// An optional starting image to use for generation.
///
/// Reading this while no value is stored yields a default-constructed `SdImage`.
public var startingImage: SdImage {
  get {
    if let stored = _startingImage {
      return stored
    }
    return SdImage()
  }
  set(image) {
    _startingImage = image
  }
}
/// Whether `startingImage` currently holds an explicitly assigned value.
public var hasStartingImage: Bool {
  _startingImage != nil
}
/// Resets `startingImage` to the unset state; later reads return its default value.
public mutating func clearStartingImage() {
  _startingImage = nil
}
///*
/// Indicates whether to enable the safety check network, if it is available.
public var enableSafetyCheck: Bool = false
///*
/// The scheduler to use for generation.
/// The default is PNDM, if not specified.
public var scheduler: SdScheduler = .pndm
///*
/// The guidance scale, which controls the influence the prompt has on the image.
/// If not specified, a reasonable default value is used.
public var guidanceScale: Float = 0
///*
/// The strength of the image generation.
/// If not specified, a reasonable default value is used.
public var strength: Float = 0
///*
/// The number of inference steps to perform.
/// If not specified, a reasonable default value is used.
public var stepCount: UInt32 = 0
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
fileprivate var _startingImage: SdImage? = nil
}
///*
@ -665,6 +702,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
5: .same(proto: "prompt"),
6: .standard(proto: "negative_prompt"),
7: .same(proto: "seed"),
8: .standard(proto: "starting_image"),
9: .standard(proto: "enable_safety_check"),
10: .same(proto: "scheduler"),
11: .standard(proto: "guidance_scale"),
12: .same(proto: "strength"),
13: .standard(proto: "step_count"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -680,12 +723,22 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
case 5: try { try decoder.decodeSingularStringField(value: &self.prompt) }()
case 6: try { try decoder.decodeSingularStringField(value: &self.negativePrompt) }()
case 7: try { try decoder.decodeSingularUInt32Field(value: &self.seed) }()
case 8: try { try decoder.decodeSingularMessageField(value: &self._startingImage) }()
case 9: try { try decoder.decodeSingularBoolField(value: &self.enableSafetyCheck) }()
case 10: try { try decoder.decodeSingularEnumField(value: &self.scheduler) }()
case 11: try { try decoder.decodeSingularFloatField(value: &self.guidanceScale) }()
case 12: try { try decoder.decodeSingularFloatField(value: &self.strength) }()
case 13: try { try decoder.decodeSingularUInt32Field(value: &self.stepCount) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every if/case branch local when no optimizations
// are enabled. https://github.com/apple/swift-protobuf/issues/1034 and
// https://github.com/apple/swift-protobuf/issues/1182
if !self.modelName.isEmpty {
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
}
@ -707,6 +760,24 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
if self.seed != 0 {
try visitor.visitSingularUInt32Field(value: self.seed, fieldNumber: 7)
}
try { if let v = self._startingImage {
try visitor.visitSingularMessageField(value: v, fieldNumber: 8)
} }()
if self.enableSafetyCheck != false {
try visitor.visitSingularBoolField(value: self.enableSafetyCheck, fieldNumber: 9)
}
if self.scheduler != .pndm {
try visitor.visitSingularEnumField(value: self.scheduler, fieldNumber: 10)
}
if self.guidanceScale != 0 {
try visitor.visitSingularFloatField(value: self.guidanceScale, fieldNumber: 11)
}
if self.strength != 0 {
try visitor.visitSingularFloatField(value: self.strength, fieldNumber: 12)
}
if self.stepCount != 0 {
try visitor.visitSingularUInt32Field(value: self.stepCount, fieldNumber: 13)
}
try unknownFields.traverse(visitor: &visitor)
}
@ -718,6 +789,12 @@ extension SdGenerateImagesRequest: SwiftProtobuf.Message, SwiftProtobuf._Message
if lhs.prompt != rhs.prompt {return false}
if lhs.negativePrompt != rhs.negativePrompt {return false}
if lhs.seed != rhs.seed {return false}
if lhs._startingImage != rhs._startingImage {return false}
if lhs.enableSafetyCheck != rhs.enableSafetyCheck {return false}
if lhs.scheduler != rhs.scheduler {return false}
if lhs.guidanceScale != rhs.guidanceScale {return false}
if lhs.strength != rhs.strength {return false}
if lhs.stepCount != rhs.stepCount {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}

View File

@ -11,9 +11,14 @@ class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
}
/// Generates images for `request` using the state of the model it names.
///
/// - Parameters:
///   - request: The generation request; `request.modelName` selects the model state.
///   - context: The gRPC call context (unused).
/// - Returns: The response produced by the model state's `generate`.
/// - Throws: `SdCoreError.modelNotFound` if no state exists for the named
///   model; any error raised during generation is printed and then rethrown.
func generateImages(request: SdGenerateImagesRequest, context _: GRPCAsyncServerCallContext) async throws -> SdGenerateImagesResponse {
  do {
    guard let state = await modelManager.getModelState(name: request.modelName) else {
      throw SdCoreError.modelNotFound
    }
    return try await state.generate(request)
  } catch {
    // Print the failure locally before propagating it to the caller.
    print(error)
    throw error
  }
}
}