Split out worker-related things into a separate service definition.

This commit is contained in:
2023-05-08 22:12:24 -07:00
parent ace2c07aa1
commit 2e5a37ea4b
28 changed files with 1271 additions and 359 deletions

View File

@ -0,0 +1,18 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
/// gRPC provider for the host model service; delegates model loading to a `ModelManager`.
class HostModelServiceProvider: SdHostModelServiceAsyncProvider {
    private let modelManager: ModelManager

    init(modelManager: ModelManager) {
        self.modelManager = modelManager
    }

    /// Creates a model state for `request.modelName`, asks it to load using the
    /// request, and replies with an empty response.
    ///
    /// - Throws: Whatever `createModelState(name:)` or `load(request:)` throws.
    func loadModel(
        request: SdLoadModelRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdLoadModelResponse {
        let modelState = try await modelManager.createModelState(name: request.modelName)
        try await modelState.load(request: request)
        return .init()
    }
}

View File

@ -0,0 +1,40 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
/// gRPC provider for the image generation service.
///
/// Each call looks up the requested model's state, creates a job through the
/// `JobManager`, marks that job as queued, and then hands the request to the
/// model state for generation.
class ImageGenerationServiceProvider: SdImageGenerationServiceAsyncProvider {
    private let jobManager: JobManager
    private let modelManager: ModelManager

    init(jobManager: JobManager, modelManager: ModelManager) {
        self.jobManager = jobManager
        self.modelManager = modelManager
    }

    /// Generates images for `request` and returns the complete response.
    ///
    /// - Throws: `SdCoreError.modelNotFound` when no model state exists for
    ///   `request.modelName`; otherwise whatever `generate` throws.
    func generateImages(
        request: SdGenerateImagesRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdGenerateImagesResponse {
        guard let state = await modelManager.getModelState(name: request.modelName) else {
            throw SdCoreError.modelNotFound
        }
        let job = await jobManager.create()
        // Await the "queued" transition directly instead of hopping through
        // `DispatchQueue.main.async { Task { … } }`: the detached hop was unordered
        // relative to `generate`, so the queued update could be applied after
        // generation had already begun touching the job.
        await jobManager.updateJobQueued(job)
        return try await state.generate(request, job: job)
    }

    /// Generates images for `request`, passing `responseStream` to the model
    /// state so it can write updates to the client. The value returned by
    /// `generate` is discarded.
    ///
    /// - Throws: `SdCoreError.modelNotFound` when no model state exists for
    ///   `request.modelName`; otherwise whatever `generate` throws.
    func generateImagesStreaming(
        request: SdGenerateImagesRequest,
        responseStream: GRPCAsyncResponseStreamWriter<SdGenerateImagesStreamUpdate>,
        context _: GRPCAsyncServerCallContext
    ) async throws {
        guard let state = await modelManager.getModelState(name: request.modelName) else {
            throw SdCoreError.modelNotFound
        }
        let job = await jobManager.create()
        // See `generateImages`: the queued transition must complete before
        // generation starts so the two cannot race.
        await jobManager.updateJobQueued(job)
        _ = try await state.generate(request, job: job, stream: responseStream)
    }
}

View File

@ -0,0 +1,37 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
/// gRPC provider for the job service, backed by a `JobManager`.
class JobServiceProvider: SdJobServiceAsyncProvider {
    private let jobManager: JobManager

    init(jobManager: JobManager) {
        self.jobManager = jobManager
    }

    /// Looks up the job with `request.id` and returns it in the response.
    ///
    /// - Throws: `SdCoreError.jobNotFound` when the manager has no such job.
    func getJob(
        request: SdGetJobRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdGetJobResponse {
        guard let foundJob = await jobManager.job(id: request.id) else {
            throw SdCoreError.jobNotFound
        }
        var reply = SdGetJobResponse()
        reply.job = foundJob
        return reply
    }

    /// Cancellation is not supported; always throws `SdCoreError.notImplemented`.
    func cancelJob(
        request _: SdCancelJobRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdCancelJobResponse {
        throw SdCoreError.notImplemented
    }

    /// Forwards job updates from the manager's `jobUpdatePublisher` to the client.
    ///
    /// A `request.id` of `0` means "no filter": every job update is sent.
    /// Any other value restricts the stream to updates for that job id.
    func streamJobUpdates(
        request: SdStreamJobUpdatesRequest,
        responseStream: GRPCAsyncResponseStreamWriter<SdJobUpdate>,
        context _: GRPCAsyncServerCallContext
    ) async throws {
        let requestedID = request.id
        let updates = await jobManager.jobUpdatePublisher
        for await job in updates where requestedID == 0 || job.id == requestedID {
            var update = SdJobUpdate()
            update.job = job
            try await responseStream.send(update)
        }
    }
}

View File

@ -0,0 +1,19 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
/// gRPC provider for the model listing service, backed by a `ModelManager`.
class ModelServiceProvider: SdModelServiceAsyncProvider {
    private let modelManager: ModelManager

    init(modelManager: ModelManager) {
        self.modelManager = modelManager
    }

    /// Returns the models reported by `listAvailableModels()` in the response's
    /// `availableModels` field.
    ///
    /// - Throws: Whatever `listAvailableModels()` throws.
    func listModels(
        request _: SdListModelsRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdListModelsResponse {
        let available = try await modelManager.listAvailableModels()
        return .with { reply in
            reply.availableModels.append(contentsOf: available)
        }
    }
}

View File

@ -0,0 +1,14 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
/// gRPC provider for the server metadata service.
class ServerMetadataServiceProvider: SdServerMetadataServiceAsyncProvider {
    /// Returns metadata whose `role` is `.node`; all other fields keep their defaults.
    func getServerMetadata(
        request _: SdGetServerMetadataRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdGetServerMetadataResponse {
        var reply = SdGetServerMetadataResponse()
        reply.metadata.role = .node
        return reply
    }
}

View File

@ -0,0 +1,19 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
/// gRPC provider for the tokenizer service, backed by a `ModelManager`.
class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider {
    private let modelManager: ModelManager

    init(modelManager: ModelManager) {
        self.modelManager = modelManager
    }

    /// Tokenizes `request` using the state of the model named `request.modelName`.
    ///
    /// - Throws: `SdCoreError.modelNotFound` when no model state exists for that
    ///   name; otherwise whatever `tokenize` throws.
    func tokenize(
        request: SdTokenizeRequest,
        context _: GRPCAsyncServerCallContext
    ) async throws -> SdTokenizeResponse {
        if let modelState = await modelManager.getModelState(name: request.modelName) {
            return try await modelState.tokenize(request)
        }
        throw SdCoreError.modelNotFound
    }
}

View File

@ -0,0 +1,52 @@
import ArgumentParser
import Foundation
import GRPC
import NIO
import StableDiffusionCore
import System
/// Command-line entry point that starts the Stable Diffusion RPC node.
///
/// Builds the job and model managers, loads the list of available models,
/// registers all service providers on an insecure gRPC server, and then parks
/// the main thread with `dispatchMain()`.
struct ServerCommand: ParsableCommand {
    @Option(name: .shortAndLong, help: "Path to models directory")
    var modelsDirectoryPath: String = "models"

    @Option(name: .long, help: "Bind host")
    var bindHost: String = "0.0.0.0"

    @Option(name: .long, help: "Bind port")
    var bindPort: Int = 4546

    /// Starts the server and never returns on success.
    ///
    /// - Throws: The bind error when the server cannot listen on
    ///   `bindHost:bindPort`.
    mutating func run() throws {
        let jobManager = JobManager()
        let modelsDirectoryURL = URL(filePath: modelsDirectoryPath)
        let modelManager = ModelManager(modelBaseURL: modelsDirectoryURL, jobManager: jobManager)

        // `run()` is synchronous, so block on a semaphore until the async model
        // scan has finished. A failure terminates the process via `exit(withError:)`.
        let semaphore = DispatchSemaphore(value: 0)
        Task {
            do {
                try await modelManager.reloadAvailableModels()
            } catch {
                ServerCommand.exit(withError: error)
            }
            semaphore.signal()
        }
        semaphore.wait()

        let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
        // Wait for the bind to complete so that a failure (port already in use,
        // unresolvable host, …) is thrown here rather than silently dropped with
        // the discarded future while the "running" message is still printed.
        _ = try Server.insecure(group: group)
            .withServiceProviders([
                ModelServiceProvider(modelManager: modelManager),
                HostModelServiceProvider(modelManager: modelManager),
                ImageGenerationServiceProvider(jobManager: jobManager, modelManager: modelManager),
                TokenizerServiceProvider(modelManager: modelManager),
                JobServiceProvider(jobManager: jobManager),
                ServerMetadataServiceProvider()
            ])
            .bind(host: bindHost, port: bindPort)
            .wait()

        print("Stable Diffusion RPC node running on \(bindHost):\(bindPort)")

        // Keep the process alive and service the main dispatch queue.
        dispatchMain()
    }
}
// Program entry point: calls the static `main()` on `ServerCommand`
// (presumably ArgumentParser's `ParsableCommand.main()`, which parses the
// command-line arguments and invokes `run()` — confirm against the library).
ServerCommand.main()