mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-02 21:20:55 +00:00
Start work on more clean client.
This commit is contained in:
parent
1c0fbe02db
commit
4430bdcdd7
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ DerivedData/
|
||||
/models
|
||||
/.vscode
|
||||
/Package.resolved
|
||||
/output.png
|
||||
|
@ -6,6 +6,7 @@ option swift_prefix = "Sd";
|
||||
message ModelInfo {
|
||||
string name = 1;
|
||||
string attention = 2;
|
||||
bool is_loaded = 3;
|
||||
}
|
||||
|
||||
message Image {
|
||||
|
@ -15,6 +15,7 @@ public actor ModelManager {
|
||||
|
||||
public func reloadModels() throws {
|
||||
modelInfos.removeAll()
|
||||
modelUrls.removeAll()
|
||||
modelStates.removeAll()
|
||||
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
|
||||
for subdirectoryURL in contents {
|
||||
@ -29,6 +30,22 @@ public actor ModelManager {
|
||||
Array(modelInfos.values)
|
||||
}
|
||||
|
||||
public func createModelState(name: String) throws -> ModelState {
|
||||
let state = modelStates[name]
|
||||
let url = modelUrls[name]
|
||||
guard let url else {
|
||||
throw SdCoreError.modelNotFound
|
||||
}
|
||||
|
||||
if state == nil {
|
||||
let state = ModelState(url: url)
|
||||
modelStates[name] = state
|
||||
return state
|
||||
} else {
|
||||
return state!
|
||||
}
|
||||
}
|
||||
|
||||
public func getModelState(name: String) -> ModelState? {
|
||||
modelStates[name]
|
||||
}
|
||||
@ -40,7 +57,6 @@ public actor ModelManager {
|
||||
info.attention = attention ?? "unknown"
|
||||
modelInfos[info.name] = info
|
||||
modelUrls[info.name] = url
|
||||
modelStates[info.name] = try ModelState(url: url)
|
||||
}
|
||||
|
||||
private func getModelAttention(_ url: URL) -> String? {
|
||||
|
@ -8,13 +8,13 @@ public actor ModelState {
|
||||
private var pipeline: StableDiffusionPipeline?
|
||||
private var tokenizer: BPETokenizer?
|
||||
|
||||
public init(url: URL) throws {
|
||||
public init(url: URL) {
|
||||
self.url = url
|
||||
}
|
||||
|
||||
public func load() throws {
|
||||
let config = MLModelConfiguration()
|
||||
config.computeUnits = .all
|
||||
config.computeUnits = .cpuAndGPU
|
||||
pipeline = try StableDiffusionPipeline(
|
||||
resourcesAt: url,
|
||||
controlNet: [],
|
||||
@ -25,6 +25,7 @@ public actor ModelState {
|
||||
let mergesUrl = url.appending(component: "merges.txt")
|
||||
let vocabUrl = url.appending(component: "vocab.json")
|
||||
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
|
||||
try pipeline?.loadResources()
|
||||
}
|
||||
|
||||
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {
|
||||
|
@ -117,6 +117,8 @@ public struct SdModelInfo {
|
||||
|
||||
public var attention: String = String()
|
||||
|
||||
public var isLoaded: Bool = false
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
@ -272,6 +274,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .same(proto: "name"),
|
||||
2: .same(proto: "attention"),
|
||||
3: .standard(proto: "is_loaded"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
@ -282,6 +285,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeSingularStringField(value: &self.name) }()
|
||||
case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }()
|
||||
case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
@ -294,12 +298,16 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
|
||||
if !self.attention.isEmpty {
|
||||
try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2)
|
||||
}
|
||||
if self.isLoaded != false {
|
||||
try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdModelInfo, rhs: SdModelInfo) -> Bool {
|
||||
if lhs.name != rhs.name {return false}
|
||||
if lhs.attention != rhs.attention {return false}
|
||||
if lhs.isLoaded != rhs.isLoaded {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
|
@ -23,9 +23,7 @@ class ModelServiceProvider: SdModelServiceAsyncProvider {
|
||||
}
|
||||
|
||||
func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdCoreError.modelNotFound
|
||||
}
|
||||
let state = try await modelManager.createModelState(name: request.modelName)
|
||||
try await state.load()
|
||||
return SdLoadModelResponse()
|
||||
}
|
||||
|
@ -15,13 +15,11 @@ struct ServerCommand: ParsableCommand {
|
||||
|
||||
let semaphore = DispatchSemaphore(value: 0)
|
||||
Task {
|
||||
print("Loading initial models...")
|
||||
do {
|
||||
try await modelManager.reloadModels()
|
||||
} catch {
|
||||
ServerCommand.exit(withError: error)
|
||||
}
|
||||
print("Loaded initial models.")
|
||||
semaphore.signal()
|
||||
}
|
||||
semaphore.wait()
|
||||
|
@ -0,0 +1,30 @@
|
||||
import Foundation
|
||||
import GRPC
|
||||
import NIO
|
||||
import StableDiffusionCore
|
||||
import StableDiffusionProtos
|
||||
|
||||
struct StableDiffusionClient {
|
||||
let group: EventLoopGroup
|
||||
let channel: GRPCChannel
|
||||
|
||||
let modelService: SdModelServiceAsyncClient
|
||||
let imageGenerationService: SdImageGenerationServiceAsyncClient
|
||||
|
||||
init(connectionTarget: ConnectionTarget, transportSecurity: GRPCChannelPool.Configuration.TransportSecurity) throws {
|
||||
group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
|
||||
|
||||
channel = try GRPCChannelPool.with(
|
||||
target: connectionTarget,
|
||||
transportSecurity: transportSecurity,
|
||||
eventLoopGroup: group
|
||||
)
|
||||
|
||||
modelService = SdModelServiceAsyncClient(channel: channel)
|
||||
imageGenerationService = SdImageGenerationServiceAsyncClient(channel: channel)
|
||||
}
|
||||
|
||||
func close() async throws {
|
||||
try await group.shutdownGracefully()
|
||||
}
|
||||
}
|
@ -4,26 +4,14 @@ import NIO
|
||||
import StableDiffusionProtos
|
||||
import System
|
||||
|
||||
let group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
|
||||
defer {
|
||||
try? group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
let channel = try GRPCChannelPool.with(
|
||||
target: .host("localhost", port: 4546),
|
||||
transportSecurity: .plaintext,
|
||||
eventLoopGroup: group
|
||||
)
|
||||
|
||||
let modelService = SdModelServiceAsyncClient(channel: channel)
|
||||
let imageGeneratorService = SdImageGenerationServiceAsyncClient(channel: channel)
|
||||
let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port: 4546), transportSecurity: .plaintext)
|
||||
|
||||
Task { @MainActor in
|
||||
do {
|
||||
let modelListResponse = try await modelService.listModels(.init())
|
||||
let modelListResponse = try await client.modelService.listModels(.init())
|
||||
print("Loading model...")
|
||||
let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }!
|
||||
_ = try await modelService.loadModel(.with { request in
|
||||
_ = try await client.modelService.loadModel(.with { request in
|
||||
request.modelName = modelInfo.name
|
||||
})
|
||||
print("Loaded model.")
|
||||
@ -35,9 +23,11 @@ Task { @MainActor in
|
||||
$0.imageCount = 1
|
||||
}
|
||||
|
||||
let response = try await imageGeneratorService.generateImage(request)
|
||||
print("Generated image.")
|
||||
print(response)
|
||||
let response = try await client.imageGenerationService.generateImage(request)
|
||||
let image = response.images.first!
|
||||
try image.content.write(to: URL(filePath: "output.png"))
|
||||
print("Generated image to output.png")
|
||||
exit(0)
|
||||
} catch {
|
||||
print(error)
|
||||
exit(1)
|
||||
|
Loading…
Reference in New Issue
Block a user