Add support for BPE tokenization.

This commit is contained in:
2023-04-23 18:49:52 -07:00
parent 0fe35cd976
commit f61fe6a18f
9 changed files with 562 additions and 17 deletions

View File

@ -35,6 +35,8 @@ dependencies {
api("io.grpc:grpc-stub:1.54.1") api("io.grpc:grpc-stub:1.54.1")
api("io.grpc:grpc-protobuf:1.54.1") api("io.grpc:grpc-protobuf:1.54.1")
api("io.grpc:grpc-kotlin-stub:1.3.0") api("io.grpc:grpc-kotlin-stub:1.3.0")
implementation("com.google.protobuf:protobuf-java:3.22.3")
implementation("io.grpc:grpc-netty:1.54.1") implementation("io.grpc:grpc-netty:1.54.1")
} }

View File

@ -1,11 +1,7 @@
package gay.pizza.stable.diffusion.sample package gay.pizza.stable.diffusion.sample
import com.google.protobuf.ByteString import com.google.protobuf.ByteString
import gay.pizza.stable.diffusion.StableDiffusion import gay.pizza.stable.diffusion.StableDiffusion.*
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.Image
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder import io.grpc.ManagedChannelBuilder
import kotlin.io.path.Path import kotlin.io.path.Path
@ -14,7 +10,11 @@ import kotlin.io.path.readBytes
import kotlin.io.path.writeBytes import kotlin.io.path.writeBytes
import kotlin.system.exitProcess import kotlin.system.exitProcess
fun main() { fun main(args: Array<String>) {
val chosenModelName = if (args.isNotEmpty()) args[0] else null
val chosenPrompt = if (args.size >= 2) args[1] else "cat"
val chosenNegativePrompt = if (args.size >= 3) args[2] else "bad, nsfw, low quality"
val channel = ManagedChannelBuilder val channel = ManagedChannelBuilder
.forAddress("127.0.0.1", 4546) .forAddress("127.0.0.1", 4546)
.usePlaintext() .usePlaintext()
@ -32,7 +32,12 @@ fun main() {
println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}") println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
} }
val model = modelListResponse.availableModelsList.random() val model = if (chosenModelName == null) {
modelListResponse.availableModelsList.random()
} else {
modelListResponse.availableModelsList.first { it.name == chosenModelName }
}
if (!model.isLoaded) { if (!model.isLoaded) {
println("loading model ${model.name}...") println("loading model ${model.name}...")
client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply { client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
@ -43,20 +48,39 @@ fun main() {
println("using model ${model.name}...") println("using model ${model.name}...")
} }
println("tokenizing prompts...")
val tokenizePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
modelName = model.name
input = chosenPrompt
}.build())
val tokenizeNegativePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
modelName = model.name
input = chosenNegativePrompt
}.build())
println("tokenize prompt='${chosenPrompt}' " +
"tokens=[${tokenizePromptResponse.tokensList.joinToString(", ")}] " +
"token_ids=[${tokenizePromptResponse.tokenIdsList.joinToString(", ")}]")
println("tokenize negative_prompt='${chosenNegativePrompt}' " +
"tokens=[${tokenizeNegativePromptResponse.tokensList.joinToString(", ")}] " +
"token_ids=[${tokenizeNegativePromptResponse.tokenIdsList.joinToString(", ")}]")
println("generating images...") println("generating images...")
val startingImagePath = Path("work/start.png") val startingImagePath = Path("work/start.png")
val request = GenerateImagesRequest.newBuilder().apply { val request = GenerateImagesRequest.newBuilder().apply {
modelName = model.name modelName = model.name
outputImageFormat = StableDiffusion.ImageFormat.png outputImageFormat = ImageFormat.png
batchSize = 2 batchSize = 2
batchCount = 2 batchCount = 2
prompt = "cat" prompt = chosenPrompt
negativePrompt = "bad, low quality, nsfw" negativePrompt = chosenNegativePrompt
if (startingImagePath.exists()) { if (startingImagePath.exists()) {
val image = Image.newBuilder().apply { val image = Image.newBuilder().apply {
format = StableDiffusion.ImageFormat.png format = ImageFormat.png
data = ByteString.copyFrom(startingImagePath.readBytes()) data = ByteString.copyFrom(startingImagePath.readBytes())
}.build() }.build()
@ -65,10 +89,12 @@ fun main() {
}.build() }.build()
for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) { for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
if (update.hasBatchProgress()) { if (update.hasBatchProgress()) {
println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%") println("batch=${update.currentBatch} " +
"progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
"overall=${prettyProgressValue(update.overallPercentageComplete)}%")
for ((index, image) in update.batchProgress.imagesList.withIndex()) { for ((index, image) in update.batchProgress.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)") println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}") val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray()) path.writeBytes(image.data.toByteArray())
} }
@ -77,13 +103,14 @@ fun main() {
if (update.hasBatchCompleted()) { if (update.hasBatchCompleted()) {
for ((index, image) in update.batchCompleted.imagesList.withIndex()) { for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1) val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)") println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
val path = Path("work/final_${imageIndex}.${image.format.name}") val path = Path("work/final_${imageIndex}.${image.format.name}")
path.writeBytes(image.data.toByteArray()) path.writeBytes(image.data.toByteArray())
} }
} }
println("overall progress ${update.overallPercentageComplete}%")
} }
channel.shutdownNow() channel.shutdownNow()
} }
fun prettyProgressValue(value: Float) = String.format("%.2f", value)

View File

@ -2,7 +2,7 @@ package gay.pizza.stable.diffusion
import io.grpc.Channel import io.grpc.Channel
@Suppress("MemberVisibilityCanBePrivate") @Suppress("MemberVisibilityCanBePrivate", "unused")
class StableDiffusionRpcClient(val channel: Channel) { class StableDiffusionRpcClient(val channel: Channel) {
val modelService: ModelServiceGrpc.ModelServiceStub by lazy { val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
ModelServiceGrpc.newStub(channel) ModelServiceGrpc.newStub(channel)
@ -35,4 +35,20 @@ class StableDiffusionRpcClient(val channel: Channel) {
val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy { val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel) ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
} }
val tokenizerService: TokenizerServiceGrpc.TokenizerServiceStub by lazy {
TokenizerServiceGrpc.newStub(channel)
}
val tokenizerServiceBlocking: TokenizerServiceGrpc.TokenizerServiceBlockingStub by lazy {
TokenizerServiceGrpc.newBlockingStub(channel)
}
val tokenizerServiceFuture: TokenizerServiceGrpc.TokenizerServiceFutureStub by lazy {
TokenizerServiceGrpc.newFutureStub(channel)
}
val tokenizerServiceCoroutine: TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub by lazy {
TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub(channel)
}
} }

View File

@ -350,3 +350,43 @@ service ImageGenerationService {
*/ */
rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate); rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate);
} }
/**
* Represents a request to tokenize an input.
*/
message TokenizeRequest {
/**
* The name of a loaded model to use for tokenization.
*/
string model_name = 1;
/**
* The input string to tokenize.
*/
string input = 2;
}
/**
* Represents a response to tokenization.
*/
message TokenizeResponse {
/**
* The tokens inside the input string.
*/
repeated string tokens = 1;
/**
* The token IDs inside the input string.
*/
repeated uint64 token_ids = 2;
}
/**
* The tokenizer service, for analyzing tokens for a loaded model.
*/
service TokenizerService {
/**
* Analyze the input using a loaded model and return the results.
*/
rpc Tokenize(TokenizeRequest) returns (TokenizeResponse);
}

View File

@ -154,4 +154,15 @@ public actor ModelState {
} }
return pipelineConfig return pipelineConfig
} }
public func tokenize(_ request: SdTokenizeRequest) throws -> SdTokenizeResponse {
guard let tokenizer else {
throw SdCoreError.modelNotLoaded
}
let results = tokenizer.tokenize(input: request.input)
var response = SdTokenizeResponse()
response.tokens = results.tokens
response.tokenIds = results.tokenIDs.map { UInt64($0) }
return response
}
} }

View File

@ -543,6 +543,199 @@ public enum SdImageGenerationServiceClientMetadata {
} }
} }
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
///
/// Usage: instantiate `SdTokenizerServiceClient`, then call methods of this protocol to make API calls.
public protocol SdTokenizerServiceClientProtocol: GRPCClient {
var serviceName: String { get }
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
func tokenize(
_ request: SdTokenizeRequest,
callOptions: CallOptions?
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse>
}
extension SdTokenizerServiceClientProtocol {
public var serviceName: String {
return "gay.pizza.stable.diffusion.TokenizerService"
}
///*
/// Analyze the input using a loaded model and return the results.
///
/// - Parameters:
/// - request: Request to send to Tokenize.
/// - callOptions: Call options.
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
public func tokenize(
_ request: SdTokenizeRequest,
callOptions: CallOptions? = nil
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
return self.makeUnaryCall(
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
)
}
}
#if compiler(>=5.6)
@available(*, deprecated)
extension SdTokenizerServiceClient: @unchecked Sendable {}
#endif // compiler(>=5.6)
@available(*, deprecated, renamed: "SdTokenizerServiceNIOClient")
public final class SdTokenizerServiceClient: SdTokenizerServiceClientProtocol {
private let lock = Lock()
private var _defaultCallOptions: CallOptions
private var _interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
public let channel: GRPCChannel
public var defaultCallOptions: CallOptions {
get { self.lock.withLock { return self._defaultCallOptions } }
set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }
}
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
get { self.lock.withLock { return self._interceptors } }
set { self.lock.withLockVoid { self._interceptors = newValue } }
}
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
///
/// - Parameters:
/// - channel: `GRPCChannel` to the service host.
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
/// - interceptors: A factory providing interceptors for each RPC.
public init(
channel: GRPCChannel,
defaultCallOptions: CallOptions = CallOptions(),
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
) {
self.channel = channel
self._defaultCallOptions = defaultCallOptions
self._interceptors = interceptors
}
}
public struct SdTokenizerServiceNIOClient: SdTokenizerServiceClientProtocol {
public var channel: GRPCChannel
public var defaultCallOptions: CallOptions
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
///
/// - Parameters:
/// - channel: `GRPCChannel` to the service host.
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
/// - interceptors: A factory providing interceptors for each RPC.
public init(
channel: GRPCChannel,
defaultCallOptions: CallOptions = CallOptions(),
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
) {
self.channel = channel
self.defaultCallOptions = defaultCallOptions
self.interceptors = interceptors
}
}
#if compiler(>=5.6)
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public protocol SdTokenizerServiceAsyncClientProtocol: GRPCClient {
static var serviceDescriptor: GRPCServiceDescriptor { get }
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
func makeTokenizeCall(
_ request: SdTokenizeRequest,
callOptions: CallOptions?
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse>
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension SdTokenizerServiceAsyncClientProtocol {
public static var serviceDescriptor: GRPCServiceDescriptor {
return SdTokenizerServiceClientMetadata.serviceDescriptor
}
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
return nil
}
public func makeTokenizeCall(
_ request: SdTokenizeRequest,
callOptions: CallOptions? = nil
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
return self.makeAsyncUnaryCall(
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
)
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension SdTokenizerServiceAsyncClientProtocol {
public func tokenize(
_ request: SdTokenizeRequest,
callOptions: CallOptions? = nil
) async throws -> SdTokenizeResponse {
return try await self.performAsyncUnaryCall(
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
)
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public struct SdTokenizerServiceAsyncClient: SdTokenizerServiceAsyncClientProtocol {
public var channel: GRPCChannel
public var defaultCallOptions: CallOptions
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
public init(
channel: GRPCChannel,
defaultCallOptions: CallOptions = CallOptions(),
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
) {
self.channel = channel
self.defaultCallOptions = defaultCallOptions
self.interceptors = interceptors
}
}
#endif // compiler(>=5.6)
public protocol SdTokenizerServiceClientInterceptorFactoryProtocol: GRPCSendable {
/// - Returns: Interceptors to use when invoking 'tokenize'.
func makeTokenizeInterceptors() -> [ClientInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
}
public enum SdTokenizerServiceClientMetadata {
public static let serviceDescriptor = GRPCServiceDescriptor(
name: "TokenizerService",
fullName: "gay.pizza.stable.diffusion.TokenizerService",
methods: [
SdTokenizerServiceClientMetadata.Methods.tokenize,
]
)
public enum Methods {
public static let tokenize = GRPCMethodDescriptor(
name: "Tokenize",
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
type: GRPCCallType.unary
)
}
}
///* ///*
/// The model service, for management and loading of models. /// The model service, for management and loading of models.
/// ///
@ -862,3 +1055,121 @@ public enum SdImageGenerationServiceServerMetadata {
) )
} }
} }
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
///
/// To build a server, implement a class that conforms to this protocol.
public protocol SdTokenizerServiceProvider: CallHandlerProvider {
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
///*
/// Analyze the input using a loaded model and return the results.
func tokenize(request: SdTokenizeRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdTokenizeResponse>
}
extension SdTokenizerServiceProvider {
public var serviceName: Substring {
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
}
/// Determines, calls and returns the appropriate request handler, depending on the request's method.
/// Returns nil for methods not handled by this service.
public func handle(
method name: Substring,
context: CallHandlerContext
) -> GRPCServerHandlerProtocol? {
switch name {
case "Tokenize":
return UnaryServerHandler(
context: context,
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
userFunction: self.tokenize(request:context:)
)
default:
return nil
}
}
}
#if compiler(>=5.6)
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
///
/// To implement a server, implement an object which conforms to this protocol.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public protocol SdTokenizerServiceAsyncProvider: CallHandlerProvider {
static var serviceDescriptor: GRPCServiceDescriptor { get }
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
///*
/// Analyze the input using a loaded model and return the results.
@Sendable func tokenize(
request: SdTokenizeRequest,
context: GRPCAsyncServerCallContext
) async throws -> SdTokenizeResponse
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension SdTokenizerServiceAsyncProvider {
public static var serviceDescriptor: GRPCServiceDescriptor {
return SdTokenizerServiceServerMetadata.serviceDescriptor
}
public var serviceName: Substring {
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
}
public var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? {
return nil
}
public func handle(
method name: Substring,
context: CallHandlerContext
) -> GRPCServerHandlerProtocol? {
switch name {
case "Tokenize":
return GRPCAsyncServerHandler(
context: context,
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
wrapping: self.tokenize(request:context:)
)
default:
return nil
}
}
}
#endif // compiler(>=5.6)
public protocol SdTokenizerServiceServerInterceptorFactoryProtocol {
/// - Returns: Interceptors to use when handling 'tokenize'.
/// Defaults to calling `self.makeInterceptors()`.
func makeTokenizeInterceptors() -> [ServerInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
}
public enum SdTokenizerServiceServerMetadata {
public static let serviceDescriptor = GRPCServiceDescriptor(
name: "TokenizerService",
fullName: "gay.pizza.stable.diffusion.TokenizerService",
methods: [
SdTokenizerServiceServerMetadata.Methods.tokenize,
]
)
public enum Methods {
public static let tokenize = GRPCMethodDescriptor(
name: "Tokenize",
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
type: GRPCCallType.unary
)
}
}

View File

@ -562,6 +562,46 @@ public struct SdGenerateImagesStreamUpdate {
public init() {} public init() {}
} }
///*
/// Represents a request to tokenize an input.
public struct SdTokenizeRequest {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The name of a loaded model to use for tokenization.
public var modelName: String = String()
///*
/// The input string to tokenize.
public var input: String = String()
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
///*
/// Represents a response to tokenization.
public struct SdTokenizeResponse {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The tokens inside the input string.
public var tokens: [String] = []
///*
/// The token IDs inside the input string.
public var tokenIds: [UInt64] = []
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
#if swift(>=5.5) && canImport(_Concurrency) #if swift(>=5.5) && canImport(_Concurrency)
extension SdModelAttention: @unchecked Sendable {} extension SdModelAttention: @unchecked Sendable {}
extension SdScheduler: @unchecked Sendable {} extension SdScheduler: @unchecked Sendable {}
@ -579,6 +619,8 @@ extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {} extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
extension SdGenerateImagesStreamUpdate: @unchecked Sendable {} extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {} extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
extension SdTokenizeRequest: @unchecked Sendable {}
extension SdTokenizeResponse: @unchecked Sendable {}
#endif // swift(>=5.5) && canImport(_Concurrency) #endif // swift(>=5.5) && canImport(_Concurrency)
// MARK: - Code below here is support for the SwiftProtobuf runtime. // MARK: - Code below here is support for the SwiftProtobuf runtime.
@ -1125,3 +1167,79 @@ extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._Me
return true return true
} }
} }
extension SdTokenizeRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".TokenizeRequest"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "model_name"),
2: .same(proto: "input"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
case 2: try { try decoder.decodeSingularStringField(value: &self.input) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if !self.modelName.isEmpty {
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
}
if !self.input.isEmpty {
try visitor.visitSingularStringField(value: self.input, fieldNumber: 2)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdTokenizeRequest, rhs: SdTokenizeRequest) -> Bool {
if lhs.modelName != rhs.modelName {return false}
if lhs.input != rhs.input {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}
extension SdTokenizeResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".TokenizeResponse"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "tokens"),
2: .standard(proto: "token_ids"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeRepeatedStringField(value: &self.tokens) }()
case 2: try { try decoder.decodeRepeatedUInt64Field(value: &self.tokenIds) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if !self.tokens.isEmpty {
try visitor.visitRepeatedStringField(value: self.tokens, fieldNumber: 1)
}
if !self.tokenIds.isEmpty {
try visitor.visitPackedUInt64Field(value: self.tokenIds, fieldNumber: 2)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdTokenizeResponse, rhs: SdTokenizeResponse) -> Bool {
if lhs.tokens != rhs.tokens {return false}
if lhs.tokenIds != rhs.tokenIds {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}

View File

@ -0,0 +1,19 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider {
private let modelManager: ModelManager
init(modelManager: ModelManager) {
self.modelManager = modelManager
}
func tokenize(request: SdTokenizeRequest, context _: GRPCAsyncServerCallContext) async throws -> SdTokenizeResponse {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdCoreError.modelNotFound
}
return try await state.tokenize(request)
}
}

View File

@ -34,7 +34,8 @@ struct ServerCommand: ParsableCommand {
_ = Server.insecure(group: group) _ = Server.insecure(group: group)
.withServiceProviders([ .withServiceProviders([
ModelServiceProvider(modelManager: modelManager), ModelServiceProvider(modelManager: modelManager),
ImageGenerationServiceProvider(modelManager: modelManager) ImageGenerationServiceProvider(modelManager: modelManager),
TokenizerServiceProvider(modelManager: modelManager)
]) ])
.bind(host: bindHost, port: bindPort) .bind(host: bindHost, port: bindPort)