Mirror of https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git, synced 2025-08-03 13:31:32 +00:00
Add support for BPE tokenization.
@ -35,6 +35,8 @@ dependencies {
  api("io.grpc:grpc-stub:1.54.1")
  api("io.grpc:grpc-protobuf:1.54.1")
  api("io.grpc:grpc-kotlin-stub:1.3.0")
  implementation("com.google.protobuf:protobuf-java:3.22.3")

  implementation("io.grpc:grpc-netty:1.54.1")
}
@@ -1,11 +1,7 @@
package gay.pizza.stable.diffusion.sample

import com.google.protobuf.ByteString
import gay.pizza.stable.diffusion.StableDiffusion
import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
import gay.pizza.stable.diffusion.StableDiffusion.Image
import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
import gay.pizza.stable.diffusion.StableDiffusion.*
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlin.io.path.Path
@@ -14,7 +10,11 @@ import kotlin.io.path.readBytes
import kotlin.io.path.writeBytes
import kotlin.system.exitProcess

fun main() {
fun main(args: Array<String>) {
  val chosenModelName = if (args.isNotEmpty()) args[0] else null
  val chosenPrompt = if (args.size >= 2) args[1] else "cat"
  val chosenNegativePrompt = if (args.size >= 3) args[2] else "bad, nsfw, low quality"

  val channel = ManagedChannelBuilder
    .forAddress("127.0.0.1", 4546)
    .usePlaintext()
@@ -32,7 +32,12 @@ fun main() {
    println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
  }

  val model = modelListResponse.availableModelsList.random()
  val model = if (chosenModelName == null) {
    modelListResponse.availableModelsList.random()
  } else {
    modelListResponse.availableModelsList.first { it.name == chosenModelName }
  }

  if (!model.isLoaded) {
    println("loading model ${model.name}...")
    client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
@@ -43,20 +48,39 @@ fun main() {
    println("using model ${model.name}...")
  }

  println("tokenizing prompts...")

  val tokenizePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
    modelName = model.name
    input = chosenPrompt
  }.build())
  val tokenizeNegativePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
    modelName = model.name
    input = chosenNegativePrompt
  }.build())

  println("tokenize prompt='${chosenPrompt}' " +
    "tokens=[${tokenizePromptResponse.tokensList.joinToString(", ")}] " +
    "token_ids=[${tokenizePromptResponse.tokenIdsList.joinToString(", ")}]")

  println("tokenize negative_prompt='${chosenNegativePrompt}' " +
    "tokens=[${tokenizeNegativePromptResponse.tokensList.joinToString(", ")}] " +
    "token_ids=[${tokenizeNegativePromptResponse.tokenIdsList.joinToString(", ")}]")

  println("generating images...")

  val startingImagePath = Path("work/start.png")

  val request = GenerateImagesRequest.newBuilder().apply {
    modelName = model.name
    outputImageFormat = StableDiffusion.ImageFormat.png
    outputImageFormat = ImageFormat.png
    batchSize = 2
    batchCount = 2
    prompt = "cat"
    negativePrompt = "bad, low quality, nsfw"
    prompt = chosenPrompt
    negativePrompt = chosenNegativePrompt
    if (startingImagePath.exists()) {
      val image = Image.newBuilder().apply {
        format = StableDiffusion.ImageFormat.png
        format = ImageFormat.png
        data = ByteString.copyFrom(startingImagePath.readBytes())
      }.build()
@@ -65,10 +89,12 @@ fun main() {
  }.build()
  for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
    if (update.hasBatchProgress()) {
      println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
      println("batch=${update.currentBatch} " +
        "progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
        "overall=${prettyProgressValue(update.overallPercentageComplete)}%")
      for ((index, image) in update.batchProgress.imagesList.withIndex()) {
        val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
        println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
        println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
        val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
        path.writeBytes(image.data.toByteArray())
      }
@@ -77,13 +103,14 @@ fun main() {
    if (update.hasBatchCompleted()) {
      for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
        val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
        println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
        println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
        val path = Path("work/final_${imageIndex}.${image.format.name}")
        path.writeBytes(image.data.toByteArray())
      }
    }
    println("overall progress ${update.overallPercentageComplete}%")
  }

  channel.shutdownNow()
}

fun prettyProgressValue(value: Float) = String.format("%.2f", value)
@@ -2,7 +2,7 @@ package gay.pizza.stable.diffusion

import io.grpc.Channel

@Suppress("MemberVisibilityCanBePrivate")
@Suppress("MemberVisibilityCanBePrivate", "unused")
class StableDiffusionRpcClient(val channel: Channel) {
  val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
    ModelServiceGrpc.newStub(channel)
@@ -35,4 +35,20 @@ class StableDiffusionRpcClient(val channel: Channel) {
  val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
    ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
  }

  val tokenizerService: TokenizerServiceGrpc.TokenizerServiceStub by lazy {
    TokenizerServiceGrpc.newStub(channel)
  }

  val tokenizerServiceBlocking: TokenizerServiceGrpc.TokenizerServiceBlockingStub by lazy {
    TokenizerServiceGrpc.newBlockingStub(channel)
  }

  val tokenizerServiceFuture: TokenizerServiceGrpc.TokenizerServiceFutureStub by lazy {
    TokenizerServiceGrpc.newFutureStub(channel)
  }

  val tokenizerServiceCoroutine: TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub by lazy {
    TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub(channel)
  }
}
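Note (not part of this diff): the sample above only exercises the blocking tokenizer stub. Below is a minimal sketch of calling the new Tokenize RPC through the coroutine stub this class now exposes, assuming the sample's localhost defaults (127.0.0.1:4546, plaintext) and kotlinx-coroutines on the classpath; the model name is a placeholder.

import gay.pizza.stable.diffusion.StableDiffusion.TokenizeRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlinx.coroutines.runBlocking

fun main() {
  // Plaintext channel to a locally running server, matching the sample's defaults.
  val channel = ManagedChannelBuilder.forAddress("127.0.0.1", 4546).usePlaintext().build()
  val client = StableDiffusionRpcClient(channel)
  runBlocking {
    // Suspend call on the generated grpc-kotlin coroutine stub; "some-loaded-model" is hypothetical.
    val response = client.tokenizerServiceCoroutine.tokenize(TokenizeRequest.newBuilder().apply {
      modelName = "some-loaded-model"
      input = "a cat sitting on a windowsill"
    }.build())
    // Pair each token with its token id for inspection.
    println(response.tokensList.zip(response.tokenIdsList))
  }
  channel.shutdownNow()
}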
@@ -350,3 +350,43 @@ service ImageGenerationService {
   */
  rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate);
}

/**
 * Represents a request to tokenize an input.
 */
message TokenizeRequest {
  /**
   * The name of a loaded model to use for tokenization.
   */
  string model_name = 1;

  /**
   * The input string to tokenize.
   */
  string input = 2;
}

/**
 * Represents a response to tokenization.
 */
message TokenizeResponse {
  /**
   * The tokens inside the input string.
   */
  repeated string tokens = 1;

  /**
   * The token IDs inside the input string.
   */
  repeated uint64 token_ids = 2;
}

/**
 * The tokenizer service, for analyzing tokens for a loaded model.
 */
service TokenizerService {
  /**
   * Analyze the input using a loaded model and return the results.
   */
  rpc Tokenize(TokenizeRequest) returns (TokenizeResponse);
}
@@ -154,4 +154,15 @@ public actor ModelState {
    }
    return pipelineConfig
  }

  public func tokenize(_ request: SdTokenizeRequest) throws -> SdTokenizeResponse {
    guard let tokenizer else {
      throw SdCoreError.modelNotLoaded
    }
    let results = tokenizer.tokenize(input: request.input)
    var response = SdTokenizeResponse()
    response.tokens = results.tokens
    response.tokenIds = results.tokenIDs.map { UInt64($0) }
    return response
  }
}
@@ -543,6 +543,199 @@ public enum SdImageGenerationServiceClientMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
///
|
||||
/// Usage: instantiate `SdTokenizerServiceClient`, then call methods of this protocol to make API calls.
|
||||
public protocol SdTokenizerServiceClientProtocol: GRPCClient {
|
||||
var serviceName: String { get }
|
||||
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
|
||||
|
||||
func tokenize(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse>
|
||||
}
|
||||
|
||||
extension SdTokenizerServiceClientProtocol {
|
||||
public var serviceName: String {
|
||||
return "gay.pizza.stable.diffusion.TokenizerService"
|
||||
}
|
||||
|
||||
///*
|
||||
/// Analyze the input using a loaded model and return the results.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - request: Request to send to Tokenize.
|
||||
/// - callOptions: Call options.
|
||||
/// - Returns: A `UnaryCall` with futures for the metadata, status and response.
|
||||
public func tokenize(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
|
||||
return self.makeUnaryCall(
|
||||
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
@available(*, deprecated)
|
||||
extension SdTokenizerServiceClient: @unchecked Sendable {}
|
||||
#endif // compiler(>=5.6)
|
||||
|
||||
@available(*, deprecated, renamed: "SdTokenizerServiceNIOClient")
|
||||
public final class SdTokenizerServiceClient: SdTokenizerServiceClientProtocol {
|
||||
private let lock = Lock()
|
||||
private var _defaultCallOptions: CallOptions
|
||||
private var _interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
|
||||
public let channel: GRPCChannel
|
||||
public var defaultCallOptions: CallOptions {
|
||||
get { self.lock.withLock { return self._defaultCallOptions } }
|
||||
set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }
|
||||
}
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
|
||||
get { self.lock.withLock { return self._interceptors } }
|
||||
set { self.lock.withLockVoid { self._interceptors = newValue } }
|
||||
}
|
||||
|
||||
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - channel: `GRPCChannel` to the service host.
|
||||
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
|
||||
/// - interceptors: A factory providing interceptors for each RPC.
|
||||
public init(
|
||||
channel: GRPCChannel,
|
||||
defaultCallOptions: CallOptions = CallOptions(),
|
||||
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
|
||||
) {
|
||||
self.channel = channel
|
||||
self._defaultCallOptions = defaultCallOptions
|
||||
self._interceptors = interceptors
|
||||
}
|
||||
}
|
||||
|
||||
public struct SdTokenizerServiceNIOClient: SdTokenizerServiceClientProtocol {
|
||||
public var channel: GRPCChannel
|
||||
public var defaultCallOptions: CallOptions
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
|
||||
|
||||
/// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - channel: `GRPCChannel` to the service host.
|
||||
/// - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
|
||||
/// - interceptors: A factory providing interceptors for each RPC.
|
||||
public init(
|
||||
channel: GRPCChannel,
|
||||
defaultCallOptions: CallOptions = CallOptions(),
|
||||
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
|
||||
) {
|
||||
self.channel = channel
|
||||
self.defaultCallOptions = defaultCallOptions
|
||||
self.interceptors = interceptors
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
public protocol SdTokenizerServiceAsyncClientProtocol: GRPCClient {
|
||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||
var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
|
||||
|
||||
func makeTokenizeCall(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions?
|
||||
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse>
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension SdTokenizerServiceAsyncClientProtocol {
|
||||
public static var serviceDescriptor: GRPCServiceDescriptor {
|
||||
return SdTokenizerServiceClientMetadata.serviceDescriptor
|
||||
}
|
||||
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
|
||||
return nil
|
||||
}
|
||||
|
||||
public func makeTokenizeCall(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
|
||||
return self.makeAsyncUnaryCall(
|
||||
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension SdTokenizerServiceAsyncClientProtocol {
|
||||
public func tokenize(
|
||||
_ request: SdTokenizeRequest,
|
||||
callOptions: CallOptions? = nil
|
||||
) async throws -> SdTokenizeResponse {
|
||||
return try await self.performAsyncUnaryCall(
|
||||
path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
|
||||
request: request,
|
||||
callOptions: callOptions ?? self.defaultCallOptions,
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
public struct SdTokenizerServiceAsyncClient: SdTokenizerServiceAsyncClientProtocol {
|
||||
public var channel: GRPCChannel
|
||||
public var defaultCallOptions: CallOptions
|
||||
public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
|
||||
|
||||
public init(
|
||||
channel: GRPCChannel,
|
||||
defaultCallOptions: CallOptions = CallOptions(),
|
||||
interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
|
||||
) {
|
||||
self.channel = channel
|
||||
self.defaultCallOptions = defaultCallOptions
|
||||
self.interceptors = interceptors
|
||||
}
|
||||
}
|
||||
|
||||
#endif // compiler(>=5.6)
|
||||
|
||||
public protocol SdTokenizerServiceClientInterceptorFactoryProtocol: GRPCSendable {
|
||||
|
||||
/// - Returns: Interceptors to use when invoking 'tokenize'.
|
||||
func makeTokenizeInterceptors() -> [ClientInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
|
||||
}
|
||||
|
||||
public enum SdTokenizerServiceClientMetadata {
|
||||
public static let serviceDescriptor = GRPCServiceDescriptor(
|
||||
name: "TokenizerService",
|
||||
fullName: "gay.pizza.stable.diffusion.TokenizerService",
|
||||
methods: [
|
||||
SdTokenizerServiceClientMetadata.Methods.tokenize,
|
||||
]
|
||||
)
|
||||
|
||||
public enum Methods {
|
||||
public static let tokenize = GRPCMethodDescriptor(
|
||||
name: "Tokenize",
|
||||
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
|
||||
type: GRPCCallType.unary
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
///*
|
||||
/// The model service, for management and loading of models.
|
||||
///
|
||||
@@ -862,3 +1055,121 @@ public enum SdImageGenerationServiceServerMetadata {
|
||||
)
|
||||
}
|
||||
}
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
///
|
||||
/// To build a server, implement a class that conforms to this protocol.
|
||||
public protocol SdTokenizerServiceProvider: CallHandlerProvider {
|
||||
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
|
||||
|
||||
///*
|
||||
/// Analyze the input using a loaded model and return the results.
|
||||
func tokenize(request: SdTokenizeRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdTokenizeResponse>
|
||||
}
|
||||
|
||||
extension SdTokenizerServiceProvider {
|
||||
public var serviceName: Substring {
|
||||
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
|
||||
}
|
||||
|
||||
/// Determines, calls and returns the appropriate request handler, depending on the request's method.
|
||||
/// Returns nil for methods not handled by this service.
|
||||
public func handle(
|
||||
method name: Substring,
|
||||
context: CallHandlerContext
|
||||
) -> GRPCServerHandlerProtocol? {
|
||||
switch name {
|
||||
case "Tokenize":
|
||||
return UnaryServerHandler(
|
||||
context: context,
|
||||
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
|
||||
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
|
||||
userFunction: self.tokenize(request:context:)
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.6)
|
||||
|
||||
///*
|
||||
/// The tokenizer service, for analyzing tokens for a loaded model.
|
||||
///
|
||||
/// To implement a server, implement an object which conforms to this protocol.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
public protocol SdTokenizerServiceAsyncProvider: CallHandlerProvider {
|
||||
static var serviceDescriptor: GRPCServiceDescriptor { get }
|
||||
var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
|
||||
|
||||
///*
|
||||
/// Analyze the input using a loaded model and return the results.
|
||||
@Sendable func tokenize(
|
||||
request: SdTokenizeRequest,
|
||||
context: GRPCAsyncServerCallContext
|
||||
) async throws -> SdTokenizeResponse
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension SdTokenizerServiceAsyncProvider {
|
||||
public static var serviceDescriptor: GRPCServiceDescriptor {
|
||||
return SdTokenizerServiceServerMetadata.serviceDescriptor
|
||||
}
|
||||
|
||||
public var serviceName: Substring {
|
||||
return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
|
||||
}
|
||||
|
||||
public var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? {
|
||||
return nil
|
||||
}
|
||||
|
||||
public func handle(
|
||||
method name: Substring,
|
||||
context: CallHandlerContext
|
||||
) -> GRPCServerHandlerProtocol? {
|
||||
switch name {
|
||||
case "Tokenize":
|
||||
return GRPCAsyncServerHandler(
|
||||
context: context,
|
||||
requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
|
||||
responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
|
||||
interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
|
||||
wrapping: self.tokenize(request:context:)
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // compiler(>=5.6)
|
||||
|
||||
public protocol SdTokenizerServiceServerInterceptorFactoryProtocol {
|
||||
|
||||
/// - Returns: Interceptors to use when handling 'tokenize'.
|
||||
/// Defaults to calling `self.makeInterceptors()`.
|
||||
func makeTokenizeInterceptors() -> [ServerInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
|
||||
}
|
||||
|
||||
public enum SdTokenizerServiceServerMetadata {
|
||||
public static let serviceDescriptor = GRPCServiceDescriptor(
|
||||
name: "TokenizerService",
|
||||
fullName: "gay.pizza.stable.diffusion.TokenizerService",
|
||||
methods: [
|
||||
SdTokenizerServiceServerMetadata.Methods.tokenize,
|
||||
]
|
||||
)
|
||||
|
||||
public enum Methods {
|
||||
public static let tokenize = GRPCMethodDescriptor(
|
||||
name: "Tokenize",
|
||||
path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
|
||||
type: GRPCCallType.unary
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@@ -562,6 +562,46 @@ public struct SdGenerateImagesStreamUpdate {
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a request to tokenize an input.
|
||||
public struct SdTokenizeRequest {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The name of a loaded model to use for tokenization.
|
||||
public var modelName: String = String()
|
||||
|
||||
///*
|
||||
/// The input string to tokenize.
|
||||
public var input: String = String()
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
///*
|
||||
/// Represents a response to tokenization.
|
||||
public struct SdTokenizeResponse {
|
||||
// SwiftProtobuf.Message conformance is added in an extension below. See the
|
||||
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
|
||||
// methods supported on all messages.
|
||||
|
||||
///*
|
||||
/// The tokens inside the input string.
|
||||
public var tokens: [String] = []
|
||||
|
||||
///*
|
||||
/// The token IDs inside the input string.
|
||||
public var tokenIds: [UInt64] = []
|
||||
|
||||
public var unknownFields = SwiftProtobuf.UnknownStorage()
|
||||
|
||||
public init() {}
|
||||
}
|
||||
|
||||
#if swift(>=5.5) && canImport(_Concurrency)
|
||||
extension SdModelAttention: @unchecked Sendable {}
|
||||
extension SdScheduler: @unchecked Sendable {}
|
||||
@@ -579,6 +619,8 @@ extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
|
||||
extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
|
||||
extension SdTokenizeRequest: @unchecked Sendable {}
|
||||
extension SdTokenizeResponse: @unchecked Sendable {}
|
||||
#endif // swift(>=5.5) && canImport(_Concurrency)
|
||||
|
||||
// MARK: - Code below here is support for the SwiftProtobuf runtime.
|
||||
@@ -1125,3 +1167,79 @@ extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._Me
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdTokenizeRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".TokenizeRequest"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .standard(proto: "model_name"),
|
||||
2: .same(proto: "input"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
|
||||
case 2: try { try decoder.decodeSingularStringField(value: &self.input) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
if !self.modelName.isEmpty {
|
||||
try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
|
||||
}
|
||||
if !self.input.isEmpty {
|
||||
try visitor.visitSingularStringField(value: self.input, fieldNumber: 2)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdTokenizeRequest, rhs: SdTokenizeRequest) -> Bool {
|
||||
if lhs.modelName != rhs.modelName {return false}
|
||||
if lhs.input != rhs.input {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension SdTokenizeResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
|
||||
public static let protoMessageName: String = _protobuf_package + ".TokenizeResponse"
|
||||
public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
|
||||
1: .same(proto: "tokens"),
|
||||
2: .standard(proto: "token_ids"),
|
||||
]
|
||||
|
||||
public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
|
||||
while let fieldNumber = try decoder.nextFieldNumber() {
|
||||
// The use of inline closures is to circumvent an issue where the compiler
|
||||
// allocates stack space for every case branch when no optimizations are
|
||||
// enabled. https://github.com/apple/swift-protobuf/issues/1034
|
||||
switch fieldNumber {
|
||||
case 1: try { try decoder.decodeRepeatedStringField(value: &self.tokens) }()
|
||||
case 2: try { try decoder.decodeRepeatedUInt64Field(value: &self.tokenIds) }()
|
||||
default: break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
|
||||
if !self.tokens.isEmpty {
|
||||
try visitor.visitRepeatedStringField(value: self.tokens, fieldNumber: 1)
|
||||
}
|
||||
if !self.tokenIds.isEmpty {
|
||||
try visitor.visitPackedUInt64Field(value: self.tokenIds, fieldNumber: 2)
|
||||
}
|
||||
try unknownFields.traverse(visitor: &visitor)
|
||||
}
|
||||
|
||||
public static func ==(lhs: SdTokenizeResponse, rhs: SdTokenizeResponse) -> Bool {
|
||||
if lhs.tokens != rhs.tokens {return false}
|
||||
if lhs.tokenIds != rhs.tokenIds {return false}
|
||||
if lhs.unknownFields != rhs.unknownFields {return false}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
Sources/StableDiffusionServer/TokenizerService.swift (new file, 19 lines)
@@ -0,0 +1,19 @@
import Foundation
import GRPC
import StableDiffusionCore
import StableDiffusionProtos

class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider {
  private let modelManager: ModelManager

  init(modelManager: ModelManager) {
    self.modelManager = modelManager
  }

  func tokenize(request: SdTokenizeRequest, context _: GRPCAsyncServerCallContext) async throws -> SdTokenizeResponse {
    guard let state = await modelManager.getModelState(name: request.modelName) else {
      throw SdCoreError.modelNotFound
    }
    return try await state.tokenize(request)
  }
}
@@ -34,7 +34,8 @@ struct ServerCommand: ParsableCommand {
    _ = Server.insecure(group: group)
      .withServiceProviders([
        ModelServiceProvider(modelManager: modelManager),
        ImageGenerationServiceProvider(modelManager: modelManager)
        ImageGenerationServiceProvider(modelManager: modelManager),
        TokenizerServiceProvider(modelManager: modelManager)
      ])
      .bind(host: bindHost, port: bindPort)