From f61fe6a18f028c48385bb848aa1635feb8843756 Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Sun, 23 Apr 2023 18:49:52 -0700 Subject: [PATCH] Add support for BPE tokenization. --- Clients/Java/build.gradle.kts | 2 + .../gay/pizza/stable/diffusion/sample/main.kt | 57 +++- .../diffusion/StableDiffusionRpcClient.kt | 18 +- Common/StableDiffusion.proto | 40 +++ Sources/StableDiffusionCore/ModelState.swift | 11 + .../StableDiffusion.grpc.swift | 311 ++++++++++++++++++ .../StableDiffusion.pb.swift | 118 +++++++ .../TokenizerService.swift | 19 ++ Sources/StableDiffusionServer/main.swift | 3 +- 9 files changed, 562 insertions(+), 17 deletions(-) create mode 100644 Sources/StableDiffusionServer/TokenizerService.swift diff --git a/Clients/Java/build.gradle.kts b/Clients/Java/build.gradle.kts index b1351a2..82ce388 100644 --- a/Clients/Java/build.gradle.kts +++ b/Clients/Java/build.gradle.kts @@ -35,6 +35,8 @@ dependencies { api("io.grpc:grpc-stub:1.54.1") api("io.grpc:grpc-protobuf:1.54.1") api("io.grpc:grpc-kotlin-stub:1.3.0") + implementation("com.google.protobuf:protobuf-java:3.22.3") + implementation("io.grpc:grpc-netty:1.54.1") } diff --git a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt index 8206d53..c248231 100644 --- a/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt +++ b/Clients/Java/sample/src/main/kotlin/gay/pizza/stable/diffusion/sample/main.kt @@ -1,11 +1,7 @@ package gay.pizza.stable.diffusion.sample import com.google.protobuf.ByteString -import gay.pizza.stable.diffusion.StableDiffusion -import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest -import gay.pizza.stable.diffusion.StableDiffusion.Image -import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest -import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest +import gay.pizza.stable.diffusion.StableDiffusion.* import 
gay.pizza.stable.diffusion.StableDiffusionRpcClient import io.grpc.ManagedChannelBuilder import kotlin.io.path.Path @@ -14,7 +10,18 @@ import kotlin.io.path.readBytes import kotlin.io.path.writeBytes import kotlin.system.exitProcess -fun main() { +/** + * Sample client: lists the server's models, loads one if needed, tokenizes both prompts, + * then streams image generation and writes the received images under `work/`. + * + * Optional positional [args]: model name (default: a random available model), + * prompt (default: "cat"), negative prompt (default: "bad, nsfw, low quality"). + */ +fun main(args: Array<String>) { + val chosenModelName = args.getOrNull(0) + val chosenPrompt = args.getOrElse(1) { "cat" } + val chosenNegativePrompt = args.getOrElse(2) { "bad, nsfw, low quality" } + val channel = ManagedChannelBuilder .forAddress("127.0.0.1", 4546) .usePlaintext() @@ -32,7 +39,13 @@ fun main() { println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}") } - val model = modelListResponse.availableModelsList.random() + val model = if (chosenModelName == null) { + modelListResponse.availableModelsList.random() + } else { + modelListResponse.availableModelsList.firstOrNull { it.name == chosenModelName } + ?: error("model '$chosenModelName' is not available on the server") + } + if (!model.isLoaded) { println("loading model ${model.name}...") client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply { @@ -43,20 +56,39 @@ fun main() { println("using model ${model.name}...") } + println("tokenizing prompts...") + + val tokenizePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply { + modelName = model.name + input = chosenPrompt + }.build()) + val tokenizeNegativePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply { + modelName = model.name + input = chosenNegativePrompt + }.build()) + + println("tokenize prompt='${chosenPrompt}' " + + "tokens=[${tokenizePromptResponse.tokensList.joinToString(", ")}] " + + "token_ids=[${tokenizePromptResponse.tokenIdsList.joinToString(", ")}]") + + println("tokenize negative_prompt='${chosenNegativePrompt}' " + + "tokens=[${tokenizeNegativePromptResponse.tokensList.joinToString(", ")}] " + + "token_ids=[${tokenizeNegativePromptResponse.tokenIdsList.joinToString(", 
")}]") + println("generating images...") val startingImagePath = Path("work/start.png") val request = GenerateImagesRequest.newBuilder().apply { modelName = model.name - outputImageFormat = StableDiffusion.ImageFormat.png + outputImageFormat = ImageFormat.png batchSize = 2 batchCount = 2 - prompt = "cat" - negativePrompt = "bad, low quality, nsfw" + prompt = chosenPrompt + negativePrompt = chosenNegativePrompt if (startingImagePath.exists()) { val image = Image.newBuilder().apply { - format = StableDiffusion.ImageFormat.png + format = ImageFormat.png data = ByteString.copyFrom(startingImagePath.readBytes()) }.build() @@ -65,10 +97,12 @@ fun main() { }.build() for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) { if (update.hasBatchProgress()) { - println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%") + println("batch=${update.currentBatch} " + + "progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " + + "overall=${prettyProgressValue(update.overallPercentageComplete)}%") for ((index, image) in update.batchProgress.imagesList.withIndex()) { val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) - println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") + println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}") path.writeBytes(image.data.toByteArray()) } @@ -77,13 +111,15 @@ fun main() { if (update.hasBatchCompleted()) { for ((index, image) in update.batchCompleted.imagesList.withIndex()) { val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) - println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") + println("image=$imageIndex format=${image.format.name} 
data=(${image.data.size()} bytes)") val path = Path("work/final_${imageIndex}.${image.format.name}") path.writeBytes(image.data.toByteArray()) } } - println("overall progress ${update.overallPercentageComplete}%") } channel.shutdownNow() } + +/** Formats [value] with two decimal places, using a locale-independent `.` decimal separator. */ +fun prettyProgressValue(value: Float): String = String.format(java.util.Locale.ROOT, "%.2f", value) diff --git a/Clients/Java/src/main/kotlin/gay/pizza/stable/diffusion/StableDiffusionRpcClient.kt b/Clients/Java/src/main/kotlin/gay/pizza/stable/diffusion/StableDiffusionRpcClient.kt index 59846e5..8da6935 100644 --- a/Clients/Java/src/main/kotlin/gay/pizza/stable/diffusion/StableDiffusionRpcClient.kt +++ b/Clients/Java/src/main/kotlin/gay/pizza/stable/diffusion/StableDiffusionRpcClient.kt @@ -2,7 +2,7 @@ package gay.pizza.stable.diffusion import io.grpc.Channel -@Suppress("MemberVisibilityCanBePrivate") +@Suppress("MemberVisibilityCanBePrivate", "unused") class StableDiffusionRpcClient(val channel: Channel) { val modelService: ModelServiceGrpc.ModelServiceStub by lazy { ModelServiceGrpc.newStub(channel) @@ -35,4 +35,20 @@ class StableDiffusionRpcClient(val channel: Channel) { val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy { ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel) } + + val tokenizerService: TokenizerServiceGrpc.TokenizerServiceStub by lazy { + TokenizerServiceGrpc.newStub(channel) + } + + val tokenizerServiceBlocking: TokenizerServiceGrpc.TokenizerServiceBlockingStub by lazy { + TokenizerServiceGrpc.newBlockingStub(channel) + } + + val tokenizerServiceFuture: TokenizerServiceGrpc.TokenizerServiceFutureStub by lazy { + TokenizerServiceGrpc.newFutureStub(channel) + } + + val tokenizerServiceCoroutine: TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub by lazy { + TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub(channel) + } } diff --git a/Common/StableDiffusion.proto b/Common/StableDiffusion.proto index b3fedfe..9e83b15 100644 ---
a/Common/StableDiffusion.proto +++ b/Common/StableDiffusion.proto @@ -350,3 +350,43 @@ service ImageGenerationService { */ rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate); } + +/** + * Represents a request to tokenize an input. + */ +message TokenizeRequest { + /** + * The name of a loaded model to use for tokenization. + */ + string model_name = 1; + + /** + * The input string to tokenize. + */ + string input = 2; +} + +/** + * Represents a response to tokenization. + */ +message TokenizeResponse { + /** + * The tokens inside the input string. + */ + repeated string tokens = 1; + + /** + * The token IDs inside the input string. + */ + repeated uint64 token_ids = 2; +} + +/** + * The tokenizer service, for analyzing tokens for a loaded model. + */ +service TokenizerService { + /** + * Analyze the input using a loaded model and return the results. + */ + rpc Tokenize(TokenizeRequest) returns (TokenizeResponse); +} diff --git a/Sources/StableDiffusionCore/ModelState.swift b/Sources/StableDiffusionCore/ModelState.swift index de3844f..84f8d7b 100644 --- a/Sources/StableDiffusionCore/ModelState.swift +++ b/Sources/StableDiffusionCore/ModelState.swift @@ -154,4 +154,15 @@ public actor ModelState { } return pipelineConfig } + + public func tokenize(_ request: SdTokenizeRequest) throws -> SdTokenizeResponse { + guard let tokenizer else { + throw SdCoreError.modelNotLoaded + } + let results = tokenizer.tokenize(input: request.input) + var response = SdTokenizeResponse() + response.tokens = results.tokens + response.tokenIds = results.tokenIDs.map { UInt64($0) } + return response + } } diff --git a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift index 96e3a24..17d962a 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.grpc.swift @@ -543,6 +543,199 @@ public enum 
SdImageGenerationServiceClientMetadata { } } +///* +/// The tokenizer service, for analyzing tokens for a loaded model. +/// +/// Usage: instantiate `SdTokenizerServiceClient`, then call methods of this protocol to make API calls. +public protocol SdTokenizerServiceClientProtocol: GRPCClient { + var serviceName: String { get } + var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get } + + func tokenize( + _ request: SdTokenizeRequest, + callOptions: CallOptions? + ) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> +} + +extension SdTokenizerServiceClientProtocol { + public var serviceName: String { + return "gay.pizza.stable.diffusion.TokenizerService" + } + + ///* + /// Analyze the input using a loaded model and return the results. + /// + /// - Parameters: + /// - request: Request to send to Tokenize. + /// - callOptions: Call options. + /// - Returns: A `UnaryCall` with futures for the metadata, status and response. + public func tokenize( + _ request: SdTokenizeRequest, + callOptions: CallOptions? = nil + ) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> { + return self.makeUnaryCall( + path: SdTokenizerServiceClientMetadata.Methods.tokenize.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [] + ) + } +} + +#if compiler(>=5.6) +@available(*, deprecated) +extension SdTokenizerServiceClient: @unchecked Sendable {} +#endif // compiler(>=5.6) + +@available(*, deprecated, renamed: "SdTokenizerServiceNIOClient") +public final class SdTokenizerServiceClient: SdTokenizerServiceClientProtocol { + private let lock = Lock() + private var _defaultCallOptions: CallOptions + private var _interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
+ public let channel: GRPCChannel + public var defaultCallOptions: CallOptions { + get { self.lock.withLock { return self._defaultCallOptions } } + set { self.lock.withLockVoid { self._defaultCallOptions = newValue } } + } + public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { + get { self.lock.withLock { return self._interceptors } } + set { self.lock.withLockVoid { self._interceptors = newValue } } + } + + /// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service. + /// + /// - Parameters: + /// - channel: `GRPCChannel` to the service host. + /// - defaultCallOptions: Options to use for each service call if the user doesn't provide them. + /// - interceptors: A factory providing interceptors for each RPC. + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self._defaultCallOptions = defaultCallOptions + self._interceptors = interceptors + } +} + +public struct SdTokenizerServiceNIOClient: SdTokenizerServiceClientProtocol { + public var channel: GRPCChannel + public var defaultCallOptions: CallOptions + public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? + + /// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service. + /// + /// - Parameters: + /// - channel: `GRPCChannel` to the service host. + /// - defaultCallOptions: Options to use for each service call if the user doesn't provide them. + /// - interceptors: A factory providing interceptors for each RPC. + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? 
= nil + ) { + self.channel = channel + self.defaultCallOptions = defaultCallOptions + self.interceptors = interceptors + } +} + +#if compiler(>=5.6) +///* +/// The tokenizer service, for analyzing tokens for a loaded model. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public protocol SdTokenizerServiceAsyncClientProtocol: GRPCClient { + static var serviceDescriptor: GRPCServiceDescriptor { get } + var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get } + + func makeTokenizeCall( + _ request: SdTokenizeRequest, + callOptions: CallOptions? + ) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdTokenizerServiceAsyncClientProtocol { + public static var serviceDescriptor: GRPCServiceDescriptor { + return SdTokenizerServiceClientMetadata.serviceDescriptor + } + + public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { + return nil + } + + public func makeTokenizeCall( + _ request: SdTokenizeRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> { + return self.makeAsyncUnaryCall( + path: SdTokenizerServiceClientMetadata.Methods.tokenize.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [] + ) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdTokenizerServiceAsyncClientProtocol { + public func tokenize( + _ request: SdTokenizeRequest, + callOptions: CallOptions? = nil + ) async throws -> SdTokenizeResponse { + return try await self.performAsyncUnaryCall( + path: SdTokenizerServiceClientMetadata.Methods.tokenize.path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: self.interceptors?.makeTokenizeInterceptors() ??
[] + ) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct SdTokenizerServiceAsyncClient: SdTokenizerServiceAsyncClientProtocol { + public var channel: GRPCChannel + public var defaultCallOptions: CallOptions + public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? + + public init( + channel: GRPCChannel, + defaultCallOptions: CallOptions = CallOptions(), + interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil + ) { + self.channel = channel + self.defaultCallOptions = defaultCallOptions + self.interceptors = interceptors + } +} + +#endif // compiler(>=5.6) + +public protocol SdTokenizerServiceClientInterceptorFactoryProtocol: GRPCSendable { + + /// - Returns: Interceptors to use when invoking 'tokenize'. + func makeTokenizeInterceptors() -> [ClientInterceptor<SdTokenizeRequest, SdTokenizeResponse>] +} + +public enum SdTokenizerServiceClientMetadata { + public static let serviceDescriptor = GRPCServiceDescriptor( + name: "TokenizerService", + fullName: "gay.pizza.stable.diffusion.TokenizerService", + methods: [ + SdTokenizerServiceClientMetadata.Methods.tokenize, + ] + ) + + public enum Methods { + public static let tokenize = GRPCMethodDescriptor( + name: "Tokenize", + path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize", + type: GRPCCallType.unary + ) + } +} + ///* /// The model service, for management and loading of models. /// @@ -862,3 +1055,121 @@ public enum SdImageGenerationServiceServerMetadata { ) } } +///* +/// The tokenizer service, for analyzing tokens for a loaded model. +/// +/// To build a server, implement a class that conforms to this protocol. +public protocol SdTokenizerServiceProvider: CallHandlerProvider { + var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get } + + ///* + /// Analyze the input using a loaded model and return the results.
+ func tokenize(request: SdTokenizeRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdTokenizeResponse> +} + +extension SdTokenizerServiceProvider { + public var serviceName: Substring { + return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...] + } + + /// Determines, calls and returns the appropriate request handler, depending on the request's method. + /// Returns nil for methods not handled by this service. + public func handle( + method name: Substring, + context: CallHandlerContext + ) -> GRPCServerHandlerProtocol? { + switch name { + case "Tokenize": + return UnaryServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(), + responseSerializer: ProtobufSerializer<SdTokenizeResponse>(), + interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [], + userFunction: self.tokenize(request:context:) + ) + + default: + return nil + } + } +} + +#if compiler(>=5.6) + +///* +/// The tokenizer service, for analyzing tokens for a loaded model. +/// +/// To implement a server, implement an object which conforms to this protocol. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public protocol SdTokenizerServiceAsyncProvider: CallHandlerProvider { + static var serviceDescriptor: GRPCServiceDescriptor { get } + var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get } + + ///* + /// Analyze the input using a loaded model and return the results. + @Sendable func tokenize( + request: SdTokenizeRequest, + context: GRPCAsyncServerCallContext + ) async throws -> SdTokenizeResponse +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension SdTokenizerServiceAsyncProvider { + public static var serviceDescriptor: GRPCServiceDescriptor { + return SdTokenizerServiceServerMetadata.serviceDescriptor + } + + public var serviceName: Substring { + return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...] + } + + public var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol?
{ + return nil + } + + public func handle( + method name: Substring, + context: CallHandlerContext + ) -> GRPCServerHandlerProtocol? { + switch name { + case "Tokenize": + return GRPCAsyncServerHandler( + context: context, + requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(), + responseSerializer: ProtobufSerializer<SdTokenizeResponse>(), + interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [], + wrapping: self.tokenize(request:context:) + ) + + default: + return nil + } + } +} + +#endif // compiler(>=5.6) + +public protocol SdTokenizerServiceServerInterceptorFactoryProtocol { + + /// - Returns: Interceptors to use when handling 'tokenize'. + /// Defaults to calling `self.makeInterceptors()`. + func makeTokenizeInterceptors() -> [ServerInterceptor<SdTokenizeRequest, SdTokenizeResponse>] +} + +public enum SdTokenizerServiceServerMetadata { + public static let serviceDescriptor = GRPCServiceDescriptor( + name: "TokenizerService", + fullName: "gay.pizza.stable.diffusion.TokenizerService", + methods: [ + SdTokenizerServiceServerMetadata.Methods.tokenize, + ] + ) + + public enum Methods { + public static let tokenize = GRPCMethodDescriptor( + name: "Tokenize", + path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize", + type: GRPCCallType.unary + ) + } +} diff --git a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift index b05b458..5c587ba 100644 --- a/Sources/StableDiffusionProtos/StableDiffusion.pb.swift +++ b/Sources/StableDiffusionProtos/StableDiffusion.pb.swift @@ -562,6 +562,46 @@ public struct SdGenerateImagesStreamUpdate { public init() {} } +///* +/// Represents a request to tokenize an input. +public struct SdTokenizeRequest { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + ///* + /// The name of a loaded model to use for tokenization.
+ public var modelName: String = String() + + ///* + /// The input string to tokenize. + public var input: String = String() + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + +///* +/// Represents a response to tokenization. +public struct SdTokenizeResponse { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + ///* + /// The tokens inside the input string. + public var tokens: [String] = [] + + ///* + /// The token IDs inside the input string. + public var tokenIds: [UInt64] = [] + + public var unknownFields = SwiftProtobuf.UnknownStorage() + + public init() {} +} + #if swift(>=5.5) && canImport(_Concurrency) extension SdModelAttention: @unchecked Sendable {} extension SdScheduler: @unchecked Sendable {} @@ -579,6 +619,8 @@ extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {} extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {} extension SdGenerateImagesStreamUpdate: @unchecked Sendable {} extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {} +extension SdTokenizeRequest: @unchecked Sendable {} +extension SdTokenizeResponse: @unchecked Sendable {} #endif // swift(>=5.5) && canImport(_Concurrency) // MARK: - Code below here is support for the SwiftProtobuf runtime. 
@@ -1125,3 +1167,79 @@ extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._Me return true } } + +extension SdTokenizeRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".TokenizeRequest" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "model_name"), + 2: .same(proto: "input"), + ] + + public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }() + case 2: try { try decoder.decodeSingularStringField(value: &self.input) }() + default: break + } + } + } + + public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws { + if !self.modelName.isEmpty { + try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1) + } + if !self.input.isEmpty { + try visitor.visitSingularStringField(value: self.input, fieldNumber: 2) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdTokenizeRequest, rhs: SdTokenizeRequest) -> Bool { + if lhs.modelName != rhs.modelName {return false} + if lhs.input != rhs.input {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SdTokenizeResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + public static let protoMessageName: String = _protobuf_package + ".TokenizeResponse" + public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .same(proto: "tokens"), + 2: .standard(proto: "token_ids"), + ] + + public mutating func
decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeRepeatedStringField(value: &self.tokens) }() + case 2: try { try decoder.decodeRepeatedUInt64Field(value: &self.tokenIds) }() + default: break + } + } + } + + public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws { + if !self.tokens.isEmpty { + try visitor.visitRepeatedStringField(value: self.tokens, fieldNumber: 1) + } + if !self.tokenIds.isEmpty { + try visitor.visitPackedUInt64Field(value: self.tokenIds, fieldNumber: 2) + } + try unknownFields.traverse(visitor: &visitor) + } + + public static func ==(lhs: SdTokenizeResponse, rhs: SdTokenizeResponse) -> Bool { + if lhs.tokens != rhs.tokens {return false} + if lhs.tokenIds != rhs.tokenIds {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} diff --git a/Sources/StableDiffusionServer/TokenizerService.swift b/Sources/StableDiffusionServer/TokenizerService.swift new file mode 100644 index 0000000..b2d5d6c --- /dev/null +++ b/Sources/StableDiffusionServer/TokenizerService.swift @@ -0,0 +1,19 @@ +import Foundation +import GRPC +import StableDiffusionCore +import StableDiffusionProtos + +class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider { + private let modelManager: ModelManager + + init(modelManager: ModelManager) { + self.modelManager = modelManager + } + + func tokenize(request: SdTokenizeRequest, context _: GRPCAsyncServerCallContext) async throws -> SdTokenizeResponse { + guard let state = await modelManager.getModelState(name: request.modelName) else { + throw SdCoreError.modelNotFound + } + return try await state.tokenize(request) + } +} diff --git a/Sources/StableDiffusionServer/main.swift
b/Sources/StableDiffusionServer/main.swift index da70322..53e9a3b 100644 --- a/Sources/StableDiffusionServer/main.swift +++ b/Sources/StableDiffusionServer/main.swift @@ -34,7 +34,8 @@ struct ServerCommand: ParsableCommand { _ = Server.insecure(group: group) .withServiceProviders([ ModelServiceProvider(modelManager: modelManager), - ImageGenerationServiceProvider(modelManager: modelManager) + ImageGenerationServiceProvider(modelManager: modelManager), + TokenizerServiceProvider(modelManager: modelManager) ]) .bind(host: bindHost, port: bindPort)