Add support for BPE tokenization.

This commit is contained in:
2023-04-23 18:49:52 -07:00
parent 0fe35cd976
commit f61fe6a18f
9 changed files with 562 additions and 17 deletions

View File

@ -154,4 +154,15 @@ public actor ModelState {
}
return pipelineConfig
}
public func tokenize(_ request: SdTokenizeRequest) throws -> SdTokenizeResponse {
guard let tokenizer else {
throw SdCoreError.modelNotLoaded
}
let results = tokenizer.tokenize(input: request.input)
var response = SdTokenizeResponse()
response.tokens = results.tokens
response.tokenIds = results.tokenIDs.map { UInt64($0) }
return response
}
}

View File

@ -543,6 +543,199 @@ public enum SdImageGenerationServiceClientMetadata {
}
}
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
///
/// Usage: instantiate `SdTokenizerServiceClient`, then call methods of this protocol to make API calls.
public protocol SdTokenizerServiceClientProtocol: GRPCClient {
var serviceName: String { get }
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
func tokenize(
_ request: SdTokenizeRequest,
callOptions: CallOptions?
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse>
}
extension SdTokenizerServiceClientProtocol {
public var serviceName: String {
return "gay.pizza.stable.diffusion.TokenizerService"
}
///*
/// Analyze the input using a loaded model and return the results.
///
/// - Parameters:
/// - request: Request to send to Tokenize.
/// - callOptions: Call options.
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
public func tokenize(
_ request: SdTokenizeRequest,
callOptions: CallOptions? = nil
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
return self.makeUnaryCall(
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
)
}
}
#if compiler(>=5.6)
@available(*, deprecated)
extension SdTokenizerServiceClient: @unchecked Sendable {}
#endif // compiler(>=5.6)
@available(*, deprecated, renamed: "SdTokenizerServiceNIOClient")
public final class SdTokenizerServiceClient: SdTokenizerServiceClientProtocol {
private let lock = Lock()
private var _defaultCallOptions: CallOptions
private var _interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
public let channel: GRPCChannel
public var defaultCallOptions: CallOptions {
get { self.lock.withLock { return self._defaultCallOptions } }
set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }
}
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
get { self.lock.withLock { return self._interceptors } }
set { self.lock.withLockVoid { self._interceptors = newValue } }
}
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
///
/// - Parameters:
/// - channel: `GRPCChannel` to the service host.
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
/// - interceptors: A factory providing interceptors for each RPC.
public init(
channel: GRPCChannel,
defaultCallOptions: CallOptions = CallOptions(),
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
) {
self.channel = channel
self._defaultCallOptions = defaultCallOptions
self._interceptors = interceptors
}
}
public struct SdTokenizerServiceNIOClient: SdTokenizerServiceClientProtocol {
public var channel: GRPCChannel
public var defaultCallOptions: CallOptions
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
///
/// - Parameters:
/// - channel: `GRPCChannel` to the service host.
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
/// - interceptors: A factory providing interceptors for each RPC.
public init(
channel: GRPCChannel,
defaultCallOptions: CallOptions = CallOptions(),
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
) {
self.channel = channel
self.defaultCallOptions = defaultCallOptions
self.interceptors = interceptors
}
}
#if compiler(>=5.6)
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public protocol SdTokenizerServiceAsyncClientProtocol: GRPCClient {
static var serviceDescriptor: GRPCServiceDescriptor { get }
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
func makeTokenizeCall(
_ request: SdTokenizeRequest,
callOptions: CallOptions?
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse>
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension SdTokenizerServiceAsyncClientProtocol {
public static var serviceDescriptor: GRPCServiceDescriptor {
return SdTokenizerServiceClientMetadata.serviceDescriptor
}
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
return nil
}
public func makeTokenizeCall(
_ request: SdTokenizeRequest,
callOptions: CallOptions? = nil
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
return self.makeAsyncUnaryCall(
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
)
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension SdTokenizerServiceAsyncClientProtocol {
public func tokenize(
_ request: SdTokenizeRequest,
callOptions: CallOptions? = nil
) async throws -> SdTokenizeResponse {
return try await self.performAsyncUnaryCall(
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
request: request,
callOptions: callOptions ?? self.defaultCallOptions,
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
)
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public struct SdTokenizerServiceAsyncClient: SdTokenizerServiceAsyncClientProtocol {
public var channel: GRPCChannel
public var defaultCallOptions: CallOptions
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
public init(
channel: GRPCChannel,
defaultCallOptions: CallOptions = CallOptions(),
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
) {
self.channel = channel
self.defaultCallOptions = defaultCallOptions
self.interceptors = interceptors
}
}
#endif // compiler(>=5.6)
public protocol SdTokenizerServiceClientInterceptorFactoryProtocol: GRPCSendable {
/// - Returns: Interceptors to use when invoking 'tokenize'.
func makeTokenizeInterceptors() -> [ClientInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
}
public enum SdTokenizerServiceClientMetadata {
public static let serviceDescriptor = GRPCServiceDescriptor(
name: "TokenizerService",
fullName: "gay.pizza.stable.diffusion.TokenizerService",
methods: [
SdTokenizerServiceClientMetadata.Methods.tokenize,
]
)
public enum Methods {
public static let tokenize = GRPCMethodDescriptor(
name: "Tokenize",
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
type: GRPCCallType.unary
)
}
}
///*
/// The model service, for management and loading of models.
///
@ -862,3 +1055,121 @@ public enum SdImageGenerationServiceServerMetadata {
)
}
}
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
///
/// To build a server, implement a class that conforms to this protocol.
public protocol SdTokenizerServiceProvider: CallHandlerProvider {
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
///*
/// Analyze the input using a loaded model and return the results.
func tokenize(request: SdTokenizeRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdTokenizeResponse>
}
extension SdTokenizerServiceProvider {
public var serviceName: Substring {
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
}
/// Determines, calls and returns the appropriate request handler, depending on the request's method.
/// Returns nil for methods not handled by this service.
public func handle(
method name: Substring,
context: CallHandlerContext
) -> GRPCServerHandlerProtocol? {
switch name {
case "Tokenize":
return UnaryServerHandler(
context: context,
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
userFunction: self.tokenize(request:context:)
)
default:
return nil
}
}
}
#if compiler(>=5.6)
///*
/// The tokenizer service, for analyzing tokens for a loaded model.
///
/// To implement a server, implement an object which conforms to this protocol.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public protocol SdTokenizerServiceAsyncProvider: CallHandlerProvider {
static var serviceDescriptor: GRPCServiceDescriptor { get }
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
///*
/// Analyze the input using a loaded model and return the results.
@Sendable func tokenize(
request: SdTokenizeRequest,
context: GRPCAsyncServerCallContext
) async throws -> SdTokenizeResponse
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension SdTokenizerServiceAsyncProvider {
public static var serviceDescriptor: GRPCServiceDescriptor {
return SdTokenizerServiceServerMetadata.serviceDescriptor
}
public var serviceName: Substring {
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
}
public var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? {
return nil
}
public func handle(
method name: Substring,
context: CallHandlerContext
) -> GRPCServerHandlerProtocol? {
switch name {
case "Tokenize":
return GRPCAsyncServerHandler(
context: context,
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
wrapping: self.tokenize(request:context:)
)
default:
return nil
}
}
}
#endif // compiler(>=5.6)
public protocol SdTokenizerServiceServerInterceptorFactoryProtocol {
/// - Returns: Interceptors to use when handling 'tokenize'.
/// Defaults to calling `self.makeInterceptors()`.
func makeTokenizeInterceptors() -> [ServerInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
}
public enum SdTokenizerServiceServerMetadata {
public static let serviceDescriptor = GRPCServiceDescriptor(
name: "TokenizerService",
fullName: "gay.pizza.stable.diffusion.TokenizerService",
methods: [
SdTokenizerServiceServerMetadata.Methods.tokenize,
]
)
public enum Methods {
public static let tokenize = GRPCMethodDescriptor(
name: "Tokenize",
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
type: GRPCCallType.unary
)
}
}

View File

@ -562,6 +562,46 @@ public struct SdGenerateImagesStreamUpdate {
public init() {}
}
///*
/// Represents a request to tokenize an input.
public struct SdTokenizeRequest {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The name of a loaded model to use for tokenization.
public var modelName: String = String()
///*
/// The input string to tokenize.
public var input: String = String()
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
///*
/// Represents a response to tokenization.
public struct SdTokenizeResponse {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
///*
/// The tokens inside the input string.
public var tokens: [String] = []
///*
/// The token IDs inside the input string.
public var tokenIds: [UInt64] = []
public var unknownFields = SwiftProtobuf.UnknownStorage()
public init() {}
}
#if swift(>=5.5) && canImport(_Concurrency)
extension SdModelAttention: @unchecked Sendable {}
extension SdScheduler: @unchecked Sendable {}
@ -579,6 +619,8 @@ extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
extension SdTokenizeRequest: @unchecked Sendable {}
extension SdTokenizeResponse: @unchecked Sendable {}
#endif // swift(>=5.5) && canImport(_Concurrency)
// MARK: - Code below here is support for the SwiftProtobuf runtime.
@ -1125,3 +1167,79 @@ extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._Me
return true
}
}
extension SdTokenizeRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".TokenizeRequest"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "model_name"),
2: .same(proto: "input"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
case 2: try { try decoder.decodeSingularStringField(value: &self.input) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if !self.modelName.isEmpty {
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
}
if !self.input.isEmpty {
try visitor.visitSingularStringField(value: self.input, fieldNumber: 2)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdTokenizeRequest, rhs: SdTokenizeRequest) -> Bool {
if lhs.modelName != rhs.modelName {return false}
if lhs.input != rhs.input {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}
extension SdTokenizeResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".TokenizeResponse"
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "tokens"),
2: .standard(proto: "token_ids"),
]
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeRepeatedStringField(value: &self.tokens) }()
case 2: try { try decoder.decodeRepeatedUInt64Field(value: &self.tokenIds) }()
default: break
}
}
}
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if !self.tokens.isEmpty {
try visitor.visitRepeatedStringField(value: self.tokens, fieldNumber: 1)
}
if !self.tokenIds.isEmpty {
try visitor.visitPackedUInt64Field(value: self.tokenIds, fieldNumber: 2)
}
try unknownFields.traverse(visitor: &visitor)
}
public static func ==(lhs: SdTokenizeResponse, rhs: SdTokenizeResponse) -> Bool {
if lhs.tokens != rhs.tokens {return false}
if lhs.tokenIds != rhs.tokenIds {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}

View File

@ -0,0 +1,19 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos
class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider {
private let modelManager: ModelManager
init(modelManager: ModelManager) {
self.modelManager = modelManager
}
func tokenize(request: SdTokenizeRequest, context _: GRPCAsyncServerCallContext) async throws -> SdTokenizeResponse {
guard let state = await modelManager.getModelState(name: request.modelName) else {
throw SdCoreError.modelNotFound
}
return try await state.tokenize(request)
}
}

View File

@ -34,7 +34,8 @@ struct ServerCommand: ParsableCommand {
_ = Server.insecure(group: group)
.withServiceProviders([
ModelServiceProvider(modelManager: modelManager),
ImageGenerationServiceProvider(modelManager: modelManager)
ImageGenerationServiceProvider(modelManager: modelManager),
TokenizerServiceProvider(modelManager: modelManager)
])
.bind(host: bindHost, port: bindPort)