mirror of
https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
synced 2025-08-05 06:21:31 +00:00
Add support for BPE tokenization.
This commit is contained in:
@ -154,4 +154,15 @@ public actor ModelState {
|
||||
}
|
||||
return pipelineConfig
|
||||
}
|
||||
|
||||
public func tokenize(_ request: SdTokenizeRequest) throws -> SdTokenizeResponse {
|
||||
guard let tokenizer else {
|
||||
throw SdCoreError.modelNotLoaded
|
||||
}
|
||||
let results = tokenizer.tokenize(input: request.input)
|
||||
var response = SdTokenizeResponse()
|
||||
response.tokens = results.tokens
|
||||
response.tokenIds = results.tokenIDs.map { UInt64($0) }
|
||||
return response
|
||||
}
|
||||
}
|
||||
|
@ -543,6 +543,199 @@ public enum SdImageGenerationServiceClientMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
///
|
||||
/// Usage: instantiate `SdTokenizerServiceClient`, then call methods of this protocol to make API calls.
|
||||
public protocol SdTokenizerServiceClientProtocol: GRPCClient {
|
||||
var serviceName: String { get }
|
||||
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
|
||||
|
||||
func tokenize(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse>
|
||||
}
|
||||
|
||||
extension SdTokenizerServiceClientProtocol {
|
||||
public var serviceName: String {
|
||||
return "gay.pizza.stable.diffusion.TokenizerService"
|
||||
}
|
||||
|
||||
///*
|
||||
/// Analyze the input using a loaded model and return the results.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - request: Request to send to Tokenize.
|
||||
/// - callOptions: Call options.
|
||||
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
|
||||
public func tokenize(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
|
||||
return self.makeUnaryCall(
|
||||
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
@available(*, deprecated)
|
||||
extension SdTokenizerServiceClient: @unchecked Sendable {}
|
||||
#endif // compiler(>=5.6)
|
||||
|
||||
@available(*, deprecated, renamed: "SdTokenizerServiceNIOClient")
|
||||
public final class SdTokenizerServiceClient: SdTokenizerServiceClientProtocol {
|
||||
private let lock = Lock()
|
||||
private var _defaultCallOptions: CallOptions
|
||||
private var _interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
|
||||
public let channel: GRPCChannel
|
||||
public var defaultCallOptions: CallOptions {
|
||||
get { self.lock.withLock { return self._defaultCallOptions } }
|
||||
set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }
|
||||
}
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
|
||||
get { self.lock.withLock { return self._interceptors } }
|
||||
set { self.lock.withLockVoid { self._interceptors = newValue } }
|
||||
}
|
||||
|
||||
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - channel: `GRPCChannel` to the service host.
|
||||
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
|
||||
/// - interceptors: A factory providing interceptors for each RPC.
|
||||
public init(
|
||||
channel: GRPCChannel,
|
||||
defaultCallOptions: CallOptions = CallOptions(),
|
||||
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
|
||||
) {
|
||||
self.channel = channel
|
||||
self._defaultCallOptions = defaultCallOptions
|
||||
self._interceptors = interceptors
|
||||
}
|
||||
}
|
||||
|
||||
public struct SdTokenizerServiceNIOClient: SdTokenizerServiceClientProtocol {
|
||||
public var channel: GRPCChannel
|
||||
public var defaultCallOptions: CallOptions
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
|
||||
|
||||
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - channel: `GRPCChannel` to the service host.
|
||||
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
|
||||
/// - interceptors: A factory providing interceptors for each RPC.
|
||||
public init(
|
||||
channel: GRPCChannel,
|
||||
defaultCallOptions: CallOptions = CallOptions(),
|
||||
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
|
||||
) {
|
||||
self.channel = channel
|
||||
self.defaultCallOptions = defaultCallOptions
|
||||
self.interceptors = interceptors
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
public protocol SdTokenizerServiceAsyncClientProtocol: GRPCClient {
|
||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
|
||||
|
||||
func makeTokenizeCall(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse>
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension SdTokenizerServiceAsyncClientProtocol {
|
||||
public static var serviceDescriptor: GRPCServiceDescriptor {
|
||||
return SdTokenizerServiceClientMetadata.serviceDescriptor
|
||||
}
|
||||
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
|
||||
return nil
|
||||
}
|
||||
|
||||
public func makeTokenizeCall(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
|
||||
return self.makeAsyncUnaryCall(
|
||||
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension SdTokenizerServiceAsyncClientProtocol {
|
||||
public func tokenize(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) async throws -> SdTokenizeResponse {
|
||||
return try await self.performAsyncUnaryCall(
|
||||
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
public struct SdTokenizerServiceAsyncClient: SdTokenizerServiceAsyncClientProtocol {
|
||||
public var channel: GRPCChannel
|
||||
public var defaultCallOptions: CallOptions
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
|
||||
|
||||
public init(
|
||||
channel: GRPCChannel,
|
||||
defaultCallOptions: CallOptions = CallOptions(),
|
||||
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
|
||||
) {
|
||||
self.channel = channel
|
||||
self.defaultCallOptions = defaultCallOptions
|
||||
self.interceptors = interceptors
|
||||
}
|
||||
}
|
||||
|
||||
#endif // compiler(>=5.6)
|
||||
|
||||
public protocol SdTokenizerServiceClientInterceptorFactoryProtocol: GRPCSendable {
|
||||
|
||||
/// - Returns: Interceptors to use when invoking 'tokenize'.
|
||||
func makeTokenizeInterceptors() -> [ClientInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
|
||||
}
|
||||
|
||||
public enum SdTokenizerServiceClientMetadata {
|
||||
public static let serviceDescriptor = GRPCServiceDescriptor(
|
||||
name: "TokenizerService",
|
||||
fullName: "gay.pizza.stable.diffusion.TokenizerService",
|
||||
methods: [
|
||||
SdTokenizerServiceClientMetadata.Methods.tokenize,
|
||||
]
|
||||
)
|
||||
|
||||
public enum Methods {
|
||||
public static let tokenize = GRPCMethodDescriptor(
|
||||
name: "Tokenize",
|
||||
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
|
||||
type: GRPCCallType.unary
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
///*
|
||||
/// The model service, for management and loading of models.
|
||||
///
|
||||
@ -862,3 +1055,121 @@ public enum SdImageGenerationServiceServerMetadata {
|
||||
)
|
||||
}
|
||||
}
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
///
|
||||
/// To build a server, implement a class that conforms to this protocol.
|
||||
public protocol SdTokenizerServiceProvider: CallHandlerProvider {
|
||||
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
|
||||
|
||||
///*
|
||||
/// Analyze the input using a loaded model and return the results.
|
||||
func tokenize(request: SdTokenizeRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdTokenizeResponse>
|
||||
}
|
||||
|
||||
extension SdTokenizerServiceProvider {
|
||||
public var serviceName: Substring {
|
||||
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
|
||||
}
|
||||
|
||||
/// Determines, calls and returns the appropriate request handler, depending on the request's method.
|
||||
/// Returns nil for methods not handled by this service.
|
||||
public func handle(
|
||||
method name: Substring,
|
||||
context: CallHandlerContext
|
||||
) -> GRPCServerHandlerProtocol? {
|
||||
switch name {
|
||||
case "Tokenize":
|
||||
return UnaryServerHandler(
|
||||
context: context,
|
||||
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
|
||||
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
|
||||
userFunction: self.tokenize(request:context:)
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
///
|
||||
/// To implement a server, implement an object which conforms to this protocol.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
public protocol SdTokenizerServiceAsyncProvider: CallHandlerProvider {
|
||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
|
||||
|
||||
///*
|
||||
/// Analyze the input using a loaded model and return the results.
|
||||
@Sendable func tokenize(
|
||||
request: SdTokenizeRequest,
|
||||
context: GRPCAsyncServerCallContext
|
||||
) async throws -> SdTokenizeResponse
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension SdTokenizerServiceAsyncProvider {
|
||||
public static var serviceDescriptor: GRPCServiceDescriptor {
|
||||
return SdTokenizerServiceServerMetadata.serviceDescriptor
|
||||
}
|
||||
|
||||
public var serviceName: Substring {
|
||||
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
|
||||
}
|
||||
|
||||
public var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? {
|
||||
return nil
|
||||
}
|
||||
|
||||
public func handle(
|
||||
method name: Substring,
|
||||
context: CallHandlerContext
|
||||
) -> GRPCServerHandlerProtocol? {
|
||||
switch name {
|
||||
case "Tokenize":
|
||||
return GRPCAsyncServerHandler(
|
||||
context: context,
|
||||
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
|
||||
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
|
||||
wrapping: self.tokenize(request:context:)
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // compiler(>=5.6)
|
||||
|
||||
public protocol SdTokenizerServiceServerInterceptorFactoryProtocol {
|
||||
|
||||
/// - Returns: Interceptors to use when handling 'tokenize'.
|
||||
/// Defaults to calling `self.makeInterceptors()`.
|
||||
func makeTokenizeInterceptors() -> [ServerInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
|
||||
}
|
||||
|
||||
public enum SdTokenizerServiceServerMetadata {
|
||||
public static let serviceDescriptor = GRPCServiceDescriptor(
|
||||
name: "TokenizerService",
|
||||
fullName: "gay.pizza.stable.diffusion.TokenizerService",
|
||||
methods: [
|
||||
SdTokenizerServiceServerMetadata.Methods.tokenize,
|
||||
]
|
||||
)
|
||||
|
||||
public enum Methods {
|
||||
public static let tokenize = GRPCMethodDescriptor(
|
||||
name: "Tokenize",
|
||||
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
|
||||
type: GRPCCallType.unary
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -562,6 +562,46 @@ public struct SdGenerateImagesStreamUpdate {
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a request to tokenize an input.
|
||||
public struct SdTokenizeRequest {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The name of a loaded model to use for tokenization.
|
||||
public var modelName: String = String()
|
||||
|
||||
///*
|
||||
/// The input string to tokenize.
|
||||
public var input: String = String()
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a response to tokenization.
|
||||
public struct SdTokenizeResponse {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The tokens inside the input string.
|
||||
public var tokens: [String] = []
|
||||
|
||||
///*
|
||||
/// The token IDs inside the input string.
|
||||
public var tokenIds: [UInt64] = []
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
#if swift(>=5.5) && canImport(_Concurrency)
|
||||
extension SdModelAttention: @unchecked Sendable {}
|
||||
extension SdScheduler: @unchecked Sendable {}
|
||||
@ -579,6 +619,8 @@ extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
|
||||
extension SdTokenizeRequest: @unchecked Sendable {}
|
||||
extension SdTokenizeResponse: @unchecked Sendable {}
|
||||
#endif // swift(>=5.5) && canImport(_Concurrency)
|
||||
|
||||
// MARK: - Code below here is support for the SwiftProtobuf runtime.
|
||||
@ -1125,3 +1167,79 @@ extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._Me
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdTokenizeRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".TokenizeRequest"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .standard(proto: "model_name"),
|
||||
2: .same(proto: "input"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
|
||||
case 2: try { try decoder.decodeSingularStringField(value: &self.input) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
if !self.modelName.isEmpty {
|
||||
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
|
||||
}
|
||||
if !self.input.isEmpty {
|
||||
try visitor.visitSingularStringField(value: self.input, fieldNumber: 2)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdTokenizeRequest, rhs: SdTokenizeRequest) -> Bool {
|
||||
if lhs.modelName != rhs.modelName {return false}
|
||||
if lhs.input != rhs.input {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdTokenizeResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".TokenizeResponse"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .same(proto: "tokens"),
|
||||
2: .standard(proto: "token_ids"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeRepeatedStringField(value: &self.tokens) }()
|
||||
case 2: try { try decoder.decodeRepeatedUInt64Field(value: &self.tokenIds) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
if !self.tokens.isEmpty {
|
||||
try visitor.visitRepeatedStringField(value: self.tokens, fieldNumber: 1)
|
||||
}
|
||||
if !self.tokenIds.isEmpty {
|
||||
try visitor.visitPackedUInt64Field(value: self.tokenIds, fieldNumber: 2)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdTokenizeResponse, rhs: SdTokenizeResponse) -> Bool {
|
||||
if lhs.tokens != rhs.tokens {return false}
|
||||
if lhs.tokenIds != rhs.tokenIds {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
19
Sources/StableDiffusionServer/TokenizerService.swift
Normal file
19
Sources/StableDiffusionServer/TokenizerService.swift
Normal file
@ -0,0 +1,19 @@
|
||||
import Foundation
|
||||
import GRPC
|
||||
import StableDiffusionCore
|
||||
import StableDiffusionProtos
|
||||
|
||||
class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider {
|
||||
private let modelManager: ModelManager
|
||||
|
||||
init(modelManager: ModelManager) {
|
||||
self.modelManager = modelManager
|
||||
}
|
||||
|
||||
func tokenize(request: SdTokenizeRequest, context _: GRPCAsyncServerCallContext) async throws -> SdTokenizeResponse {
|
||||
guard let state = await modelManager.getModelState(name: request.modelName) else {
|
||||
throw SdCoreError.modelNotFound
|
||||
}
|
||||
return try await state.tokenize(request)
|
||||
}
|
||||
}
|
@ -34,7 +34,8 @@ struct ServerCommand: ParsableCommand {
|
||||
_ = Server.insecure(group: group)
|
||||
.withServiceProviders([
|
||||
ModelServiceProvider(modelManager: modelManager),
|
||||
ImageGenerationServiceProvider(modelManager: modelManager)
|
||||
ImageGenerationServiceProvider(modelManager: modelManager),
|
||||
TokenizerServiceProvider(modelManager: modelManager)
|
||||
])
|
||||
.bind(host: bindHost, port: bindPort)
|
||||
|
||||
|
Reference in New Issue
Block a user