Start work on more clean client.

This commit is contained in:
Alex Zenla 2023-04-22 16:32:54 -07:00
parent 1c0fbe02db
commit 4430bdcdd7
Signed by: alex
GPG Key ID: C0780728420EBFE5
9 changed files with 69 additions and 26 deletions

1
.gitignore vendored
View File

@ -9,3 +9,4 @@ DerivedData/
/models
/.vscode
/Package.resolved
/output.png

View File

@ -6,6 +6,7 @@ option swift_prefix = "Sd";
message ModelInfo {
string name = 1;
string attention = 2;
bool is_loaded = 3;
}
message Image {

View File

@ -15,6 +15,7 @@ public actor ModelManager {
public func reloadModels() throws {
modelInfos.removeAll()
modelUrls.removeAll()
modelStates.removeAll()
let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey])
for subdirectoryURL in contents {
@ -29,6 +30,22 @@ public actor ModelManager {
Array(modelInfos.values)
}
public func createModelState(name: String) throws -> ModelState {
let state = modelStates[name]
let url = modelUrls[name]
guard let url else {
throw SdCoreError.modelNotFound
}
if state == nil {
let state = ModelState(url: url)
modelStates[name] = state
return state
} else {
return state!
}
}
public func getModelState(name: String) -> ModelState? {
modelStates[name]
}
@ -40,7 +57,6 @@ public actor ModelManager {
info.attention = attention ?? "unknown"
modelInfos[info.name] = info
modelUrls[info.name] = url
modelStates[info.name] = try ModelState(url: url)
}
private func getModelAttention(_ url: URL) -> String? {

View File

@ -8,13 +8,13 @@ public actor ModelState {
private var pipeline: StableDiffusionPipeline?
private var tokenizer: BPETokenizer?
public init(url: URL) throws {
public init(url: URL) {
self.url = url
}
public func load() throws {
let config = MLModelConfiguration()
config.computeUnits = .all
config.computeUnits = .cpuAndGPU
pipeline = try StableDiffusionPipeline(
resourcesAt: url,
controlNet: [],
@ -25,6 +25,7 @@ public actor ModelState {
let mergesUrl = url.appending(component: "merges.txt")
let vocabUrl = url.appending(component: "vocab.json")
tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl)
try pipeline?.loadResources()
}
public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse {

View File

@ -117,6 +117,8 @@ public struct SdModelInfo {
public var attention: String = String()
public var isLoaded: Bool = false
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
@ -272,6 +274,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "name"),
2: .same(proto: "attention"),
3: .standard(proto: "is_loaded"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -282,6 +285,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
switch fieldNumber {
case 1: try { try decoder.decodeSingularStringField(value: &self.name) }()
case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }()
case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }()
default: break
}
}
@ -294,12 +298,16 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati
if !self.attention.isEmpty {
try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2)
}
if self.isLoaded != false {
try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdModelInfo, rhs: SdModelInfo) -> Bool {
if lhs.name != rhs.name {return false}
if lhs.attention != rhs.attention {return false}
if lhs.isLoaded != rhs.isLoaded {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}

View File

@ -23,9 +23,7 @@ class ModelServiceProvider: SdModelServiceAsyncProvider {
}
func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdCoreError.modelNotFound
}
let state = try await modelManager.createModelState(name: request.modelName)
try await state.load()
return SdLoadModelResponse()
}

View File

@ -15,13 +15,11 @@ struct ServerCommand: ParsableCommand {
let semaphore = DispatchSemaphore(value: 0)
Task {
print("Loading initial models...")
do {
try await modelManager.reloadModels()
} catch {
ServerCommand.exit(withError: error)
}
print("Loaded initial models.")
semaphore.signal()
}
semaphore.wait()

View File

@ -0,0 +1,30 @@
import Foundation
import GRPC
import NIO
import StableDiffusionCore
import StableDiffusionProtos
struct StableDiffusionClient {
let group: EventLoopGroup
let channel: GRPCChannel
let modelService: SdModelServiceAsyncClient
let imageGenerationService: SdImageGenerationServiceAsyncClient
init(connectionTarget: ConnectionTarget, transportSecurity: GRPCChannelPool.Configuration.TransportSecurity) throws {
group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
channel = try GRPCChannelPool.with(
target: connectionTarget,
transportSecurity: transportSecurity,
eventLoopGroup: group
)
modelService = SdModelServiceAsyncClient(channel: channel)
imageGenerationService = SdImageGenerationServiceAsyncClient(channel: channel)
}
func close() async throws {
try await group.shutdownGracefully()
}
}

View File

@ -4,26 +4,14 @@ import NIO
import StableDiffusionProtos
import System
let group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
defer {
try? group.syncShutdownGracefully()
}
let channel = try GRPCChannelPool.with(
target: .host("localhost", port: 4546),
transportSecurity: .plaintext,
eventLoopGroup: group
)
let modelService = SdModelServiceAsyncClient(channel: channel)
let imageGeneratorService = SdImageGenerationServiceAsyncClient(channel: channel)
let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port: 4546), transportSecurity: .plaintext)
Task { @MainActor in
do {
let modelListResponse = try await modelService.listModels(.init())
let modelListResponse = try await client.modelService.listModels(.init())
print("Loading model...")
let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }!
_ = try await modelService.loadModel(.with { request in
_ = try await client.modelService.loadModel(.with { request in
request.modelName = modelInfo.name
})
print("Loaded model.")
@ -35,9 +23,11 @@ Task { @MainActor in
$0.imageCount = 1
}
let response = try await imageGeneratorService.generateImage(request)
print("Generated image.")
print(response)
let response = try await client.imageGenerationService.generateImage(request)
let image = response.images.first!
try image.content.write(to: URL(filePath: "output.png"))
print("Generated image to output.png")
exit(0)
} catch {
print(error)
exit(1)