Start work on more clean client.

This commit is contained in:
2023-04-22 16:32:54 -07:00
parent 1c0fbe02db
commit 4430bdcdd7
9 changed files with 69 additions and 26 deletions

View File

@ -0,0 +1,30 @@
import Foundation
import GRPC
import NIO
import StableDiffusionCore
import StableDiffusionProtos
struct StableDiffusionClient {
let group: EventLoopGroup
let channel: GRPCChannel
let modelService: SdModelServiceAsyncClient
let imageGenerationService: SdImageGenerationServiceAsyncClient
init(connectionTarget: ConnectionTarget, transportSecurity: GRPCChannelPool.Configuration.TransportSecurity) throws {
group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
channel = try GRPCChannelPool.with(
target: connectionTarget,
transportSecurity: transportSecurity,
eventLoopGroup: group
)
modelService = SdModelServiceAsyncClient(channel: channel)
imageGenerationService = SdImageGenerationServiceAsyncClient(channel: channel)
}
func close() async throws {
try await group.shutdownGracefully()
}
}

View File

@ -4,26 +4,14 @@ import NIO
import StableDiffusionProtos
import System
let group = PlatformSupport.makeEventLoopGroup(loopCount: 1)
defer {
try? group.syncShutdownGracefully()
}
let channel = try GRPCChannelPool.with(
target: .host("localhost", port: 4546),
transportSecurity: .plaintext,
eventLoopGroup: group
)
let modelService = SdModelServiceAsyncClient(channel: channel)
let imageGeneratorService = SdImageGenerationServiceAsyncClient(channel: channel)
let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port: 4546), transportSecurity: .plaintext)
Task { @MainActor in
do {
let modelListResponse = try await modelService.listModels(.init())
let modelListResponse = try await client.modelService.listModels(.init())
print("Loading model...")
let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }!
_ = try await modelService.loadModel(.with { request in
_ = try await client.modelService.loadModel(.with { request in
request.modelName = modelInfo.name
})
print("Loaded model.")
@ -35,9 +23,11 @@ Task { @MainActor in
$0.imageCount = 1
}
let response = try await imageGeneratorService.generateImage(request)
print("Generated image.")
print(response)
let response = try await client.imageGenerationService.generateImage(request)
let image = response.images.first!
try image.content.write(to: URL(filePath: "output.png"))
print("Generated image to output.png")
exit(0)
} catch {
print(error)
exit(1)