diff --git a/.gitignore b/.gitignore index a43e767..9ce62ae 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ DerivedData/ /models /.vscode /Package.resolved +/output.png diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index c443f81..258d9e9 100644 --- a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -6,6 +6,7 @@ option swift_prefix = "Sd"; message ModelInfo { string name = 1; string attention = 2; + bool is_loaded = 3; } message Image { diff --git a/Sources/StableDiffusionCore/ModelManager.swift b/Sources/StableDiffusionCore/ModelManager.swift index 67fe366..cfdf020 100644 --- a/Sources/StableDiffusionCore/ModelManager.swift +++ b/Sources/StableDiffusionCore/ModelManager.swift @@ -15,6 +15,7 @@ public actor ModelManager { public func reloadModels() throws { modelInfos.removeAll() + modelUrls.removeAll() modelStates.removeAll() let contents = try FileManager.default.contentsOfDirectory(at: modelBaseURL.resolvingSymlinksInPath(), includingPropertiesForKeys: [.isDirectoryKey]) for subdirectoryURL in contents { @@ -29,6 +30,22 @@ public actor ModelManager { Array(modelInfos.values) } + public func createModelState(name: String) throws -> ModelState { + let state = modelStates[name] + let url = modelUrls[name] + guard let url else { + throw SdCoreError.modelNotFound + } + + if state == nil { + let state = ModelState(url: url) + modelStates[name] = state + return state + } else { + return state! + } + } + public func getModelState(name: String) -> ModelState? { modelStates[name] } @@ -40,7 +57,6 @@ public actor ModelManager { info.attention = attention ?? "unknown" modelInfos[info.name] = info modelUrls[info.name] = url - modelStates[info.name] = try ModelState(url: url) } private func getModelAttention(_ url: URL) -> String? { diff --git a/Sources/StableDiffusionCore/ModelState.swift b/Sources/StableDiffusionCore/ModelState.swift index f0ef37c..0d58c21 100644 --- a/Sources/StableDiffusionCore/ModelState.swift +++ b/Sources/StableDiffusionCore/ModelState.swift @@ -8,13 +8,13 @@ public actor ModelState { private var pipeline: StableDiffusionPipeline? private var tokenizer: BPETokenizer? - public init(url: URL) throws { + public init(url: URL) { self.url = url } public func load() throws { let config = MLModelConfiguration() - config.computeUnits = .all + config.computeUnits = .cpuAndGPU pipeline = try StableDiffusionPipeline( resourcesAt: url, controlNet: [], @@ -25,6 +25,7 @@ public actor ModelState { let mergesUrl = url.appending(component: "merges.txt") let vocabUrl = url.appending(component: "vocab.json") tokenizer = try BPETokenizer(mergesAt: mergesUrl, vocabularyAt: vocabUrl) + try pipeline?.loadResources() } public func generate(_ request: SdGenerateImagesRequest) throws -> SdGenerateImagesResponse { diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift index f56ec7b..4c4360c 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -117,6 +117,8 @@ public struct SdModelInfo { public var attention: String = String() + public var isLoaded: Bool = false + public var unknownFields = SwiftProtobuf.UnknownStorage() public init() {} @@ -272,6 +274,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .same(proto: "name"), 2: .same(proto: "attention"), + 3: .standard(proto: "is_loaded"), ] public mutating func decodeMessage(decoder: inout D) throws { @@ -282,6 +285,7 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati switch fieldNumber { case 1: try { try decoder.decodeSingularStringField(value: &self.name) }() case 2: try { try decoder.decodeSingularStringField(value: &self.attention) }() + case 3: try { try decoder.decodeSingularBoolField(value: &self.isLoaded) }() default: break } } @@ -294,12 +298,16 @@ extension SdModelInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementati if !self.attention.isEmpty { try visitor.visitSingularStringField(value: self.attention, fieldNumber: 2) } + if self.isLoaded != false { + try visitor.visitSingularBoolField(value: self.isLoaded, fieldNumber: 3) + } try unknownFields.traverse(visitor: &visitor) } public static func ==(lhs: SdModelInfo, rhs: SdModelInfo) -> Bool { if lhs.name != rhs.name {return false} if lhs.attention != rhs.attention {return false} + if lhs.isLoaded != rhs.isLoaded {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } diff --git a/Sources/StableDiffusionServer/ModelService.swift b/Sources/StableDiffusionServer/ModelService.swift index 96af18a..17d2a9e 100644 --- a/Sources/StableDiffusionServer/ModelService.swift +++ b/Sources/StableDiffusionServer/ModelService.swift @@ -23,9 +23,7 @@ class ModelServiceProvider: SdModelServiceAsyncProvider { } func loadModel(request: SdLoadModelRequest, context _: GRPCAsyncServerCallContext) async throws -> SdLoadModelResponse { - guard let state = await modelManager.getModelState(name: request.modelName) else { - throw SdCoreError.modelNotFound - } + let state = try await modelManager.createModelState(name: request.modelName) try await state.load() return SdLoadModelResponse() } diff --git a/Sources/StableDiffusionServer/main.swift b/Sources/StableDiffusionServer/main.swift index 560a791..d0384ba 100644 --- a/Sources/StableDiffusionServer/main.swift +++ b/Sources/StableDiffusionServer/main.swift @@ -15,13 +15,11 @@ struct ServerCommand: ParsableCommand { let semaphore = DispatchSemaphore(value: 0) Task { - print("Loading initial models...") do { try await modelManager.reloadModels() } catch { ServerCommand.exit(withError: error) } - print("Loaded initial models.") semaphore.signal() } semaphore.wait() diff --git a/Sources/TestStableDiffusionClient/StableDiffusionClient.swift b/Sources/TestStableDiffusionClient/StableDiffusionClient.swift new file mode 100644 index 0000000..2c0d90a --- /dev/null +++ b/Sources/TestStableDiffusionClient/StableDiffusionClient.swift @@ -0,0 +1,30 @@ +import Foundation +import GRPC +import NIO +import StableDiffusionCore +import StableDiffusionProtos + +struct StableDiffusionClient { + let group: EventLoopGroup + let channel: GRPCChannel + + let modelService: SdModelServiceAsyncClient + let imageGenerationService: SdImageGenerationServiceAsyncClient + + init(connectionTarget: ConnectionTarget, transportSecurity: GRPCChannelPool.Configuration.TransportSecurity) throws { + group = PlatformSupport.makeEventLoopGroup(loopCount: 1) + + channel = try GRPCChannelPool.with( + target: connectionTarget, + transportSecurity: transportSecurity, + eventLoopGroup: group + ) + + modelService = SdModelServiceAsyncClient(channel: channel) + imageGenerationService = SdImageGenerationServiceAsyncClient(channel: channel) + } + + func close() async throws { + try await group.shutdownGracefully() + } +} diff --git a/Sources/TestStableDiffusionClient/main.swift b/Sources/TestStableDiffusionClient/main.swift index 58c4ba6..34c9098 100644 --- a/Sources/TestStableDiffusionClient/main.swift +++ b/Sources/TestStableDiffusionClient/main.swift @@ -4,26 +4,14 @@ import NIO import StableDiffusionProtos import System -let group = PlatformSupport.makeEventLoopGroup(loopCount: 1) -defer { - try? group.syncShutdownGracefully() -} - -let channel = try GRPCChannelPool.with( - target: .host("localhost", port: 4546), - transportSecurity: .plaintext, - eventLoopGroup: group -) - -let modelService = SdModelServiceAsyncClient(channel: channel) -let imageGeneratorService = SdImageGenerationServiceAsyncClient(channel: channel) +let client = try StableDiffusionClient(connectionTarget: .host("127.0.0.1", port: 4546), transportSecurity: .plaintext) Task { @MainActor in do { - let modelListResponse = try await modelService.listModels(.init()) + let modelListResponse = try await client.modelService.listModels(.init()) print("Loading model...") let modelInfo = modelListResponse.models.first { $0.name == "anything-4.5" }! - _ = try await modelService.loadModel(.with { request in + _ = try await client.modelService.loadModel(.with { request in request.modelName = modelInfo.name }) print("Loaded model.") @@ -35,9 +23,11 @@ Task { @MainActor in $0.imageCount = 1 } - let response = try await imageGeneratorService.generateImage(request) - print("Generated image.") - print(response) + let response = try await client.imageGenerationService.generateImage(request) + let image = response.images.first! + try image.content.write(to: URL(filePath: "output.png")) + print("Generated image to output.png") + exit(0) } catch { print(error) exit(1)