Start work on more clean client.

This commit is contained in:
2023-04-22 16:32:54 -07:00
parent 1c0fbe02db
commit 4430bdcdd7
9 changed files with 69 additions and 26 deletions

1
.gitignore vendored
View File

@ -9,3 +9,4 @@ DerivedData/
/models /models
/.vscode /.vscode
/Package.resolved /Package.resolved
/output.png

View File

@ -6,6 +6,7 @@ option swift_prefix = "Sd";
message ModelInfo { message ModelInfo {
string name = 1; string name = 1;
string attention = 2; string attention = 2;
bool is_loaded = 3;
} }
message Image { message Image {

View File

@ -15,6 +15,7 @@ public actor ModelManager {
public func reloadModels() throws { public func reloadModels() throws {
modelInfos.removeAll() modelInfos.removeAll()
modelUrls.removeAll()
modelStates.removeAll() modelStates.removeAll()
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey]) let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
for subdirectoryURL in contents { for subdirectoryURL in contents {
@ -29,6 +30,22 @@ public actor ModelManager {
Array(modelInfos.values) Array(modelInfos.values)
} }
public func createModelState(name: String) throws -> ModelState {
let state = modelStates[name]
let url = modelUrls[name]
guard let url else {
throw SdCoreError.modelNotFound
}
if state == nil {
let state = ModelState(url: url)
modelStates[name] = state
return state
} else {
return state!
}
}
public func getModelState(name: String) -> ModelState? { public func getModelState(name: String) -> ModelState? {
modelStates[name] modelStates[name]
} }
@ -40,7 +57,6 @@ public actor ModelManager {
info.attention = attention ?? "unknown" info.attention = attention ?? "unknown"
modelInfos[info.name] = info modelInfos[info.name] = info
modelUrls[info.name] = url modelUrls[info.name] = url
modelStates[info.name] = try ModelState(url: url)
} }
private func getModelAttention(_ url: URL) -> String? { private func getModelAttention(_ url: URL) -> String? {

View File

@ -8,13 +8,13 @@ public actor ModelState {
private var pipeline: StableDiffusionPipeline? private var pipeline: StableDiffusionPipeline?
private var tokenizer: BPETokenizer? private var tokenizer: BPETokenizer?
public init(url: URL) throws { public init(url: URL) {
self.url = url self.url = url
} }
public func load() throws { public func load() throws {
let config = MLModelConfiguration() let config = MLModelConfiguration()
config.computeUnits = .all config.computeUnits = .cpuAndGPU
pipeline = try StableDiffusionPipeline( pipeline = try StableDiffusionPipeline(
resourcesAt: url, resourcesAt: url,
controlNet: [], controlNet: [],
@ -25,6 +25,7 @@ public actor ModelState {
let mergesUrl = url.appending(component: "merges.txt") let mergesUrl = url.appending(component: "merges.txt")
let vocabUrl = url.appending(component: "vocab.json") let vocabUrl = url.appending(component: "vocab.json")
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl) tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
try pipeline?.loadResources()
} }
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse { public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {

View File

@ -117,6 +117,8 @@ public struct SdModelInfo {
public var attention: String = String() public var attention: String = String()
public var isLoaded: Bool = false
public var unknownFields = SwiftProtobuf.UnknownStorage() public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {} public init() {}
@ -272,6 +274,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "name"), 1: .same(proto: "name"),
2: .same(proto: "attention"), 2: .same(proto: "attention"),
3: .standard(proto: "is_loaded"),
] ]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws { public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -282,6 +285,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
switch fieldNumber { switch fieldNumber {
case 1: try { try decoder.decodeSingularStringField(value: &self.name) }() case 1: try { try decoder.decodeSingularStringField(value: &self.name) }()
case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }() case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }()
case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }()
default: break default: break
} }
} }
@ -294,12 +298,16 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
if !self.attention.isEmpty { if !self.attention.isEmpty {
try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2) try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2)
} }
if self.isLoaded != false {
try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3)
}
try unknownFields.traverse(visitor: &visitor) try unknownFields.traverse(visitor: &visitor)
} }
public static func ==(lhs: SdModelInfo, rhs: SdModelInfo) -> Bool { public static func ==(lhs: SdModelInfo, rhs: SdModelInfo) -> Bool {
if lhs.name != rhs.name {return false} if lhs.name != rhs.name {return false}
if lhs.attention != rhs.attention {return false} if lhs.attention != rhs.attention {return false}
if lhs.isLoaded != rhs.isLoaded {return false}
if lhs.unknownFields != rhs.unknownFields {return false} if lhs.unknownFields != rhs.unknownFields {return false}
return true return true
} }

View File

@ -23,9 +23,7 @@ class ModelServiceProvider: SdModelServiceAsyncProvider {
} }
func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse { func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
guard let state = await modelManager.getModelState(name: request.modelName) else { let state = try await modelManager.createModelState(name: request.modelName)
throw SdCoreError.modelNotFound
}
try await state.load() try await state.load()
return SdLoadModelResponse() return SdLoadModelResponse()
} }

View File

@ -15,13 +15,11 @@ struct ServerCommand: ParsableCommand {
let semaphore = DispatchSemaphore(value: 0) let semaphore = DispatchSemaphore(value: 0)
Task { Task {
print("Loading initial models...")
do { do {
try await modelManager.reloadModels() try await modelManager.reloadModels()
} catch { } catch {
ServerCommand.exit(withError: error) ServerCommand.exit(withError: error)
} }
print("Loaded initial models.")
semaphore.signal() semaphore.signal()
} }
semaphore.wait() semaphore.wait()

View File

@ -0,0 +1,30 @@
import Foundation
import GRPC
import NIO
import StableDiffusionCore
import StableDiffusionProtos
struct StableDiffusionClient {
let group: EventLoopGroup
let channel: GRPCChannel
let modelService: SdModelServiceAsyncClient
let imageGenerationService: SdImageGenerationServiceAsyncClient
init(connectionTarget: ConnectionTarget, transportSecurity: GRPCChannelPool.Configuration.TransportSecurity) throws {
group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
channel = try GRPCChannelPool.with(
target: connectionTarget,
transportSecurity: transportSecurity,
eventLoopGroup: group
)
modelService = SdModelServiceAsyncClient(channel: channel)
imageGenerationService = SdImageGenerationServiceAsyncClient(channel: channel)
}
func close() async throws {
try await group.shutdownGracefully()
}
}

View File

@ -4,26 +4,14 @@ import NIO
import StableDiffusionProtos import StableDiffusionProtos
import System import System
let group = PlatformSupport.makeEventLoopGroup(loopCount: 1) let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port: 4546), transportSecurity: .plaintext)
defer {
try? group.syncShutdownGracefully()
}
let channel = try GRPCChannelPool.with(
target: .host("localhost", port: 4546),
transportSecurity: .plaintext,
eventLoopGroup: group
)
let modelService = SdModelServiceAsyncClient(channel: channel)
let imageGeneratorService = SdImageGenerationServiceAsyncClient(channel: channel)
Task { @MainActor in Task { @MainActor in
do { do {
let modelListResponse = try await modelService.listModels(.init()) let modelListResponse = try await client.modelService.listModels(.init())
print("Loading model...") print("Loading model...")
let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }! let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }!
_ = try await modelService.loadModel(.with { request in _ = try await client.modelService.loadModel(.with { request in
request.modelName = modelInfo.name request.modelName = modelInfo.name
}) })
print("Loaded model.") print("Loaded model.")
@ -35,9 +23,11 @@ Task { @MainActor in
$0.imageCount = 1 $0.imageCount = 1
} }
let response = try await imageGeneratorService.generateImage(request) let response = try await client.imageGenerationService.generateImage(request)
print("Generated image.") let image = response.images.first!
print(response) try image.content.write(to: URL(filePath: "output.png"))
print("Generated image to output.png")
exit(0)
} catch { } catch {
print(error) print(error)
exit(1) exit(1)