Mirror of https://github.com/GayPizzaSpecifications/stable-diffusion-rpc.git
(synced 2025-08-03 13:31:32 +00:00)

Commit: Add support for BPE tokenization.
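In short, this commit adds a TokenizerService with a single Tokenize RPC to the protocol definition, regenerates the Swift client and server bindings, implements the service on the server against the loaded model's tokenizer, exposes tokenizer stubs on the Kotlin StableDiffusionRpcClient, and updates the sample to accept a model name and prompts as arguments and tokenize them before generating images.

A minimal Kotlin sketch of the new surface from a client's point of view, assuming a server is listening on 127.0.0.1:4546 as in the bundled sample; the model name "some-model" is a placeholder, not part of this commit:

import gay.pizza.stable.diffusion.StableDiffusion.TokenizeRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder

fun main() {
  // Plaintext channel to a locally running server, mirroring the sample.
  val channel = ManagedChannelBuilder
    .forAddress("127.0.0.1", 4546)
    .usePlaintext()
    .build()
  val client = StableDiffusionRpcClient(channel)

  // Tokenize a prompt with the new blocking tokenizer stub.
  val response = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
    modelName = "some-model" // placeholder: use any loaded model reported by ListModels
    input = "a cat sitting on a windowsill"
  }.build())
  println("tokens=${response.tokensList} token_ids=${response.tokenIdsList}")

  channel.shutdownNow()
}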
@@ -35,6 +35,8 @@ dependencies {
   api("io.grpc:grpc-stub:1.54.1")
   api("io.grpc:grpc-protobuf:1.54.1")
   api("io.grpc:grpc-kotlin-stub:1.3.0")
+  implementation("com.google.protobuf:protobuf-java:3.22.3")
+
   implementation("io.grpc:grpc-netty:1.54.1")
 }
 
@@ -1,11 +1,7 @@
 package gay.pizza.stable.diffusion.sample
 
 import com.google.protobuf.ByteString
-import gay.pizza.stable.diffusion.StableDiffusion
-import gay.pizza.stable.diffusion.StableDiffusion.GenerateImagesRequest
-import gay.pizza.stable.diffusion.StableDiffusion.Image
-import gay.pizza.stable.diffusion.StableDiffusion.ListModelsRequest
-import gay.pizza.stable.diffusion.StableDiffusion.LoadModelRequest
+import gay.pizza.stable.diffusion.StableDiffusion.*
 import gay.pizza.stable.diffusion.StableDiffusionRpcClient
 import io.grpc.ManagedChannelBuilder
 import kotlin.io.path.Path
@@ -14,7 +10,11 @@ import kotlin.io.path.readBytes
 import kotlin.io.path.writeBytes
 import kotlin.system.exitProcess
 
-fun main() {
+fun main(args: Array<String>) {
+  val chosenModelName = if (args.isNotEmpty()) args[0] else null
+  val chosenPrompt = if (args.size >= 2) args[1] else "cat"
+  val chosenNegativePrompt = if (args.size >= 3) args[2] else "bad, nsfw, low quality"
+
   val channel = ManagedChannelBuilder
     .forAddress("127.0.0.1", 4546)
     .usePlaintext()
@@ -32,7 +32,12 @@ fun main() {
     println(" model ${model.name} attention=${model.attention} loaded=${model.isLoaded}${maybeLoadedComputeUnits}")
   }
 
-  val model = modelListResponse.availableModelsList.random()
+  val model = if (chosenModelName == null) {
+    modelListResponse.availableModelsList.random()
+  } else {
+    modelListResponse.availableModelsList.first { it.name == chosenModelName }
+  }
+
   if (!model.isLoaded) {
     println("loading model ${model.name}...")
     client.modelServiceBlocking.loadModel(LoadModelRequest.newBuilder().apply {
@@ -43,20 +48,39 @@ fun main() {
     println("using model ${model.name}...")
   }
 
+  println("tokenizing prompts...")
+
+  val tokenizePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
+    modelName = model.name
+    input = chosenPrompt
+  }.build())
+  val tokenizeNegativePromptResponse = client.tokenizerServiceBlocking.tokenize(TokenizeRequest.newBuilder().apply {
+    modelName = model.name
+    input = chosenNegativePrompt
+  }.build())
+
+  println("tokenize prompt='${chosenPrompt}' " +
+    "tokens=[${tokenizePromptResponse.tokensList.joinToString(", ")}] " +
+    "token_ids=[${tokenizePromptResponse.tokenIdsList.joinToString(", ")}]")
+
+  println("tokenize negative_prompt='${chosenNegativePrompt}' " +
+    "tokens=[${tokenizeNegativePromptResponse.tokensList.joinToString(", ")}] " +
+    "token_ids=[${tokenizeNegativePromptResponse.tokenIdsList.joinToString(", ")}]")
+
   println("generating images...")
 
   val startingImagePath = Path("work/start.png")
 
   val request = GenerateImagesRequest.newBuilder().apply {
     modelName = model.name
-    outputImageFormat = StableDiffusion.ImageFormat.png
+    outputImageFormat = ImageFormat.png
     batchSize = 2
     batchCount = 2
-    prompt = "cat"
-    negativePrompt = "bad, low quality, nsfw"
+    prompt = chosenPrompt
+    negativePrompt = chosenNegativePrompt
     if (startingImagePath.exists()) {
       val image = Image.newBuilder().apply {
-        format = StableDiffusion.ImageFormat.png
+        format = ImageFormat.png
         data = ByteString.copyFrom(startingImagePath.readBytes())
       }.build()
 
@@ -65,10 +89,12 @@ fun main() {
   }.build()
   for ((updateIndex, update) in client.imageGenerationServiceBlocking.generateImagesStreaming(request).withIndex()) {
     if (update.hasBatchProgress()) {
-      println("batch ${update.currentBatch} progress ${update.batchProgress.percentageComplete}%")
+      println("batch=${update.currentBatch} " +
+        "progress=${prettyProgressValue(update.batchProgress.percentageComplete)}% " +
+        "overall=${prettyProgressValue(update.overallPercentageComplete)}%")
       for ((index, image) in update.batchProgress.imagesList.withIndex()) {
         val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
-        println("image $imageIndex update $updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
+        println("image=$imageIndex update=$updateIndex format=${image.format.name} data=(${image.data.size()} bytes)")
         val path = Path("work/intermediate_${imageIndex}_${updateIndex}.${image.format.name}")
         path.writeBytes(image.data.toByteArray())
       }
@@ -77,13 +103,14 @@ fun main() {
     if (update.hasBatchCompleted()) {
       for ((index, image) in update.batchCompleted.imagesList.withIndex()) {
         val imageIndex = ((update.currentBatch - 1) * request.batchSize) + (index + 1)
-        println("image $imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
+        println("image=$imageIndex format=${image.format.name} data=(${image.data.size()} bytes)")
         val path = Path("work/final_${imageIndex}.${image.format.name}")
         path.writeBytes(image.data.toByteArray())
       }
     }
-    println("overall progress ${update.overallPercentageComplete}%")
   }
 
   channel.shutdownNow()
 }
+
+fun prettyProgressValue(value: Float) = String.format("%.2f", value)
@@ -2,7 +2,7 @@ package gay.pizza.stable.diffusion
 
 import io.grpc.Channel
 
-@Suppress("MemberVisibilityCanBePrivate")
+@Suppress("MemberVisibilityCanBePrivate", "unused")
 class StableDiffusionRpcClient(val channel: Channel) {
   val modelService: ModelServiceGrpc.ModelServiceStub by lazy {
     ModelServiceGrpc.newStub(channel)
@@ -35,4 +35,20 @@ class StableDiffusionRpcClient(val channel: Channel) {
   val imageGenerationServiceCoroutine: ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub by lazy {
     ImageGenerationServiceGrpcKt.ImageGenerationServiceCoroutineStub(channel)
   }
+
+  val tokenizerService: TokenizerServiceGrpc.TokenizerServiceStub by lazy {
+    TokenizerServiceGrpc.newStub(channel)
+  }
+
+  val tokenizerServiceBlocking: TokenizerServiceGrpc.TokenizerServiceBlockingStub by lazy {
+    TokenizerServiceGrpc.newBlockingStub(channel)
+  }
+
+  val tokenizerServiceFuture: TokenizerServiceGrpc.TokenizerServiceFutureStub by lazy {
+    TokenizerServiceGrpc.newFutureStub(channel)
+  }
+
+  val tokenizerServiceCoroutine: TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub by lazy {
+    TokenizerServiceGrpcKt.TokenizerServiceCoroutineStub(channel)
+  }
 }
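The client now exposes the tokenizer in the same four flavours as the other services: async, blocking, future, and coroutine stubs. A short sketch of the coroutine variant follows, again assuming a local server on 127.0.0.1:4546 and a placeholder model name; the suspend tokenize signature is the usual one grpc-kotlin generates for a unary RPC, which is an assumption about the generated stub rather than something shown in this diff:

import gay.pizza.stable.diffusion.StableDiffusion.TokenizeRequest
import gay.pizza.stable.diffusion.StableDiffusionRpcClient
import io.grpc.ManagedChannelBuilder
import kotlinx.coroutines.runBlocking

fun main() {
  val channel = ManagedChannelBuilder
    .forAddress("127.0.0.1", 4546)
    .usePlaintext()
    .build()
  val client = StableDiffusionRpcClient(channel)

  runBlocking {
    // The coroutine stub exposes Tokenize as a suspend function.
    val response = client.tokenizerServiceCoroutine.tokenize(TokenizeRequest.newBuilder().apply {
      modelName = "some-model" // placeholder model name
      input = "bad, nsfw, low quality"
    }.build())
    println("tokens=${response.tokensList} token_ids=${response.tokenIdsList}")
  }

  channel.shutdownNow()
}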
@@ -350,3 +350,43 @@ service ImageGenerationService {
    */
   rpc GenerateImagesStreaming(GenerateImagesRequest) returns (stream GenerateImagesStreamUpdate);
 }
+
+/**
+ * Represents a request to tokenize an input.
+ */
+message TokenizeRequest {
+  /**
+   * The name of a loaded model to use for tokenization.
+   */
+  string model_name = 1;
+
+  /**
+   * The input string to tokenize.
+   */
+  string input = 2;
+}
+
+/**
+ * Represents a response to tokenization.
+ */
+message TokenizeResponse {
+  /**
+   * The tokens inside the input string.
+   */
+  repeated string tokens = 1;
+
+  /**
+   * The token IDs inside the input string.
+   */
+  repeated uint64 token_ids = 2;
+}
+
+/**
+ * The tokenizer service, for analyzing tokens for a loaded model.
+ */
+service TokenizerService {
+  /**
+   * Analyze the input using a loaded model and return the results.
+   */
+  rpc Tokenize(TokenizeRequest) returns (TokenizeResponse);
+}
@@ -154,4 +154,15 @@ public actor ModelState {
     }
     return pipelineConfig
   }
+
+  public func tokenize(_ request: SdTokenizeRequest) throws -> SdTokenizeResponse {
+    guard let tokenizer else {
+      throw SdCoreError.modelNotLoaded
+    }
+    let results = tokenizer.tokenize(input: request.input)
+    var response = SdTokenizeResponse()
+    response.tokens = results.tokens
+    response.tokenIds = results.tokenIDs.map { UInt64($0) }
+    return response
+  }
 }
@@ -543,6 +543,199 @@ public enum SdImageGenerationServiceClientMetadata {
   }
 }
 
+///*
+/// The tokenizer service, for analyzing tokens for a loaded model.
+///
+/// Usage: instantiate `SdTokenizerServiceClient`, then call methods of this protocol to make API calls.
+public protocol SdTokenizerServiceClientProtocol: GRPCClient {
+  var serviceName: String { get }
+  var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
+
+  func tokenize(
+    _ request: SdTokenizeRequest,
+    callOptions: CallOptions?
+  ) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse>
+}
+
+extension SdTokenizerServiceClientProtocol {
+  public var serviceName: String {
+    return "gay.pizza.stable.diffusion.TokenizerService"
+  }
+
+  ///*
+  /// Analyze the input using a loaded model and return the results.
+  ///
+  /// - Parameters:
+  ///   - request: Request to send to Tokenize.
+  ///   - callOptions: Call options.
+  /// - Returns: A `UnaryCall` with futures for the metadata, status and response.
+  public func tokenize(
+    _ request: SdTokenizeRequest,
+    callOptions: CallOptions? = nil
+  ) -> UnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
+    return self.makeUnaryCall(
+      path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
+      request: request,
+      callOptions: callOptions ?? self.defaultCallOptions,
+      interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
+    )
+  }
+}
+
+#if compiler(>=5.6)
+@available(*, deprecated)
+extension SdTokenizerServiceClient: @unchecked Sendable {}
+#endif // compiler(>=5.6)
+
+@available(*, deprecated, renamed: "SdTokenizerServiceNIOClient")
+public final class SdTokenizerServiceClient: SdTokenizerServiceClientProtocol {
+  private let lock = Lock()
+  private var _defaultCallOptions: CallOptions
+  private var _interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
+  public let channel: GRPCChannel
+  public var defaultCallOptions: CallOptions {
+    get { self.lock.withLock { return self._defaultCallOptions } }
+    set { self.lock.withLockVoid { self._defaultCallOptions = newValue } }
+  }
+  public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
+    get { self.lock.withLock { return self._interceptors } }
+    set { self.lock.withLockVoid { self._interceptors = newValue } }
+  }
+
+  /// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
+  ///
+  /// - Parameters:
+  ///   - channel: `GRPCChannel` to the service host.
+  ///   - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
+  ///   - interceptors: A factory providing interceptors for each RPC.
+  public init(
+    channel: GRPCChannel,
+    defaultCallOptions: CallOptions = CallOptions(),
+    interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
+  ) {
+    self.channel = channel
+    self._defaultCallOptions = defaultCallOptions
+    self._interceptors = interceptors
+  }
+}
+
+public struct SdTokenizerServiceNIOClient: SdTokenizerServiceClientProtocol {
+  public var channel: GRPCChannel
+  public var defaultCallOptions: CallOptions
+  public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
+
+  /// Creates a client for the gay.pizza.stable.diffusion.TokenizerService service.
+  ///
+  /// - Parameters:
+  ///   - channel: `GRPCChannel` to the service host.
+  ///   - defaultCallOptions: Options to use for each service call if the user doesn't provide them.
+  ///   - interceptors: A factory providing interceptors for each RPC.
+  public init(
+    channel: GRPCChannel,
+    defaultCallOptions: CallOptions = CallOptions(),
+    interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
+  ) {
+    self.channel = channel
+    self.defaultCallOptions = defaultCallOptions
+    self.interceptors = interceptors
+  }
+}
+
+#if compiler(>=5.6)
+///*
+/// The tokenizer service, for analyzing tokens for a loaded model.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public protocol SdTokenizerServiceAsyncClientProtocol: GRPCClient {
+  static var serviceDescriptor: GRPCServiceDescriptor { get }
+  var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? { get }
+
+  func makeTokenizeCall(
+    _ request: SdTokenizeRequest,
+    callOptions: CallOptions?
+  ) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse>
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension SdTokenizerServiceAsyncClientProtocol {
+  public static var serviceDescriptor: GRPCServiceDescriptor {
+    return SdTokenizerServiceClientMetadata.serviceDescriptor
+  }
+
+  public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? {
+    return nil
+  }
+
+  public func makeTokenizeCall(
+    _ request: SdTokenizeRequest,
+    callOptions: CallOptions? = nil
+  ) -> GRPCAsyncUnaryCall<SdTokenizeRequest, SdTokenizeResponse> {
+    return self.makeAsyncUnaryCall(
+      path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
+      request: request,
+      callOptions: callOptions ?? self.defaultCallOptions,
+      interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
+    )
+  }
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension SdTokenizerServiceAsyncClientProtocol {
+  public func tokenize(
+    _ request: SdTokenizeRequest,
+    callOptions: CallOptions? = nil
+  ) async throws -> SdTokenizeResponse {
+    return try await self.performAsyncUnaryCall(
+      path: SdTokenizerServiceClientMetadata.Methods.tokenize.path,
+      request: request,
+      callOptions: callOptions ?? self.defaultCallOptions,
+      interceptors: self.interceptors?.makeTokenizeInterceptors() ?? []
+    )
+  }
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct SdTokenizerServiceAsyncClient: SdTokenizerServiceAsyncClientProtocol {
+  public var channel: GRPCChannel
+  public var defaultCallOptions: CallOptions
+  public var interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol?
+
+  public init(
+    channel: GRPCChannel,
+    defaultCallOptions: CallOptions = CallOptions(),
+    interceptors: SdTokenizerServiceClientInterceptorFactoryProtocol? = nil
+  ) {
+    self.channel = channel
+    self.defaultCallOptions = defaultCallOptions
+    self.interceptors = interceptors
+  }
+}
+
+#endif // compiler(>=5.6)
+
+public protocol SdTokenizerServiceClientInterceptorFactoryProtocol: GRPCSendable {
+
+  /// - Returns: Interceptors to use when invoking 'tokenize'.
+  func makeTokenizeInterceptors() -> [ClientInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
+}
+
+public enum SdTokenizerServiceClientMetadata {
+  public static let serviceDescriptor = GRPCServiceDescriptor(
+    name: "TokenizerService",
+    fullName: "gay.pizza.stable.diffusion.TokenizerService",
+    methods: [
+      SdTokenizerServiceClientMetadata.Methods.tokenize,
+    ]
+  )
+
+  public enum Methods {
+    public static let tokenize = GRPCMethodDescriptor(
+      name: "Tokenize",
+      path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
+      type: GRPCCallType.unary
+    )
+  }
+}
+
 ///*
 /// The model service, for management and loading of models.
 ///
@@ -862,3 +1055,121 @@ public enum SdImageGenerationServiceServerMetadata {
     )
   }
 }
+///*
+/// The tokenizer service, for analyzing tokens for a loaded model.
+///
+/// To build a server, implement a class that conforms to this protocol.
+public protocol SdTokenizerServiceProvider: CallHandlerProvider {
+  var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
+
+  ///*
+  /// Analyze the input using a loaded model and return the results.
+  func tokenize(request: SdTokenizeRequest, context: StatusOnlyCallContext) -> EventLoopFuture<SdTokenizeResponse>
+}
+
+extension SdTokenizerServiceProvider {
+  public var serviceName: Substring {
+    return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
+  }
+
+  /// Determines, calls and returns the appropriate request handler, depending on the request's method.
+  /// Returns nil for methods not handled by this service.
+  public func handle(
+    method name: Substring,
+    context: CallHandlerContext
+  ) -> GRPCServerHandlerProtocol? {
+    switch name {
+    case "Tokenize":
+      return UnaryServerHandler(
+        context: context,
+        requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
+        responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
+        interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
+        userFunction: self.tokenize(request:context:)
+      )
+
+    default:
+      return nil
+    }
+  }
+}
+
+#if compiler(>=5.6)
+
+///*
+/// The tokenizer service, for analyzing tokens for a loaded model.
+///
+/// To implement a server, implement an object which conforms to this protocol.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public protocol SdTokenizerServiceAsyncProvider: CallHandlerProvider {
+  static var serviceDescriptor: GRPCServiceDescriptor { get }
+  var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? { get }
+
+  ///*
+  /// Analyze the input using a loaded model and return the results.
+  @Sendable func tokenize(
+    request: SdTokenizeRequest,
+    context: GRPCAsyncServerCallContext
+  ) async throws -> SdTokenizeResponse
+}
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension SdTokenizerServiceAsyncProvider {
+  public static var serviceDescriptor: GRPCServiceDescriptor {
+    return SdTokenizerServiceServerMetadata.serviceDescriptor
+  }
+
+  public var serviceName: Substring {
+    return SdTokenizerServiceServerMetadata.serviceDescriptor.fullName[...]
+  }
+
+  public var interceptors: SdTokenizerServiceServerInterceptorFactoryProtocol? {
+    return nil
+  }
+
+  public func handle(
+    method name: Substring,
+    context: CallHandlerContext
+  ) -> GRPCServerHandlerProtocol? {
+    switch name {
+    case "Tokenize":
+      return GRPCAsyncServerHandler(
+        context: context,
+        requestDeserializer: ProtobufDeserializer<SdTokenizeRequest>(),
+        responseSerializer: ProtobufSerializer<SdTokenizeResponse>(),
+        interceptors: self.interceptors?.makeTokenizeInterceptors() ?? [],
+        wrapping: self.tokenize(request:context:)
+      )
+
+    default:
+      return nil
+    }
+  }
+}
+
+#endif // compiler(>=5.6)
+
+public protocol SdTokenizerServiceServerInterceptorFactoryProtocol {
+
+  /// - Returns: Interceptors to use when handling 'tokenize'.
+  /// Defaults to calling `self.makeInterceptors()`.
+  func makeTokenizeInterceptors() -> [ServerInterceptor<SdTokenizeRequest, SdTokenizeResponse>]
+}
+
+public enum SdTokenizerServiceServerMetadata {
+  public static let serviceDescriptor = GRPCServiceDescriptor(
+    name: "TokenizerService",
+    fullName: "gay.pizza.stable.diffusion.TokenizerService",
+    methods: [
+      SdTokenizerServiceServerMetadata.Methods.tokenize,
+    ]
+  )
+
+  public enum Methods {
+    public static let tokenize = GRPCMethodDescriptor(
+      name: "Tokenize",
+      path: "/gay.pizza.stable.diffusion.TokenizerService/Tokenize",
+      type: GRPCCallType.unary
+    )
+  }
+}
@@ -562,6 +562,46 @@ public struct SdGenerateImagesStreamUpdate {
   public init() {}
 }
 
+///*
+/// Represents a request to tokenize an input.
+public struct SdTokenizeRequest {
+  // SwiftProtobuf.Message conformance is added in an extension below. See the
+  // `Message` and `Message+*Additions` files in the SwiftProtobuf library for
+  // methods supported on all messages.
+
+  ///*
+  /// The name of a loaded model to use for tokenization.
+  public var modelName: String = String()
+
+  ///*
+  /// The input string to tokenize.
+  public var input: String = String()
+
+  public var unknownFields = SwiftProtobuf.UnknownStorage()
+
+  public init() {}
+}
+
+///*
+/// Represents a response to tokenization.
+public struct SdTokenizeResponse {
+  // SwiftProtobuf.Message conformance is added in an extension below. See the
+  // `Message` and `Message+*Additions` files in the SwiftProtobuf library for
+  // methods supported on all messages.
+
+  ///*
+  /// The tokens inside the input string.
+  public var tokens: [String] = []
+
+  ///*
+  /// The token IDs inside the input string.
+  public var tokenIds: [UInt64] = []
+
+  public var unknownFields = SwiftProtobuf.UnknownStorage()
+
+  public init() {}
+}
+
 #if swift(>=5.5) && canImport(_Concurrency)
 extension SdModelAttention: @unchecked Sendable {}
 extension SdScheduler: @unchecked Sendable {}
@@ -579,6 +619,8 @@ extension SdGenerateImagesBatchProgressUpdate: @unchecked Sendable {}
 extension SdGenerateImagesBatchCompletedUpdate: @unchecked Sendable {}
 extension SdGenerateImagesStreamUpdate: @unchecked Sendable {}
 extension SdGenerateImagesStreamUpdate.OneOf_Update: @unchecked Sendable {}
+extension SdTokenizeRequest: @unchecked Sendable {}
+extension SdTokenizeResponse: @unchecked Sendable {}
 #endif // swift(>=5.5) && canImport(_Concurrency)
 
 // MARK: - Code below here is support for the SwiftProtobuf runtime.
@@ -1125,3 +1167,79 @@ extension SdGenerateImagesStreamUpdate: SwiftProtobuf.Message, SwiftProtobuf._Me
     return true
   }
 }
+
+extension SdTokenizeRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
+  public static let protoMessageName: String = _protobuf_package + ".TokenizeRequest"
+  public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
+    1: .standard(proto: "model_name"),
+    2: .same(proto: "input"),
+  ]
+
+  public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
+    while let fieldNumber = try decoder.nextFieldNumber() {
+      // The use of inline closures is to circumvent an issue where the compiler
+      // allocates stack space for every case branch when no optimizations are
+      // enabled. https://github.com/apple/swift-protobuf/issues/1034
+      switch fieldNumber {
+      case 1: try { try decoder.decodeSingularStringField(value: &self.modelName) }()
+      case 2: try { try decoder.decodeSingularStringField(value: &self.input) }()
+      default: break
+      }
+    }
+  }
+
+  public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
+    if !self.modelName.isEmpty {
+      try visitor.visitSingularStringField(value: self.modelName, fieldNumber: 1)
+    }
+    if !self.input.isEmpty {
+      try visitor.visitSingularStringField(value: self.input, fieldNumber: 2)
+    }
+    try unknownFields.traverse(visitor: &visitor)
+  }
+
+  public static func ==(lhs: SdTokenizeRequest, rhs: SdTokenizeRequest) -> Bool {
+    if lhs.modelName != rhs.modelName {return false}
+    if lhs.input != rhs.input {return false}
+    if lhs.unknownFields != rhs.unknownFields {return false}
+    return true
+  }
+}
+
+extension SdTokenizeResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
+  public static let protoMessageName: String = _protobuf_package + ".TokenizeResponse"
+  public static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
+    1: .same(proto: "tokens"),
+    2: .standard(proto: "token_ids"),
+  ]
+
+  public mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
+    while let fieldNumber = try decoder.nextFieldNumber() {
+      // The use of inline closures is to circumvent an issue where the compiler
+      // allocates stack space for every case branch when no optimizations are
+      // enabled. https://github.com/apple/swift-protobuf/issues/1034
+      switch fieldNumber {
+      case 1: try { try decoder.decodeRepeatedStringField(value: &self.tokens) }()
+      case 2: try { try decoder.decodeRepeatedUInt64Field(value: &self.tokenIds) }()
+      default: break
+      }
+    }
+  }
+
+  public func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
+    if !self.tokens.isEmpty {
+      try visitor.visitRepeatedStringField(value: self.tokens, fieldNumber: 1)
+    }
+    if !self.tokenIds.isEmpty {
+      try visitor.visitPackedUInt64Field(value: self.tokenIds, fieldNumber: 2)
+    }
+    try unknownFields.traverse(visitor: &visitor)
+  }
+
+  public static func ==(lhs: SdTokenizeResponse, rhs: SdTokenizeResponse) -> Bool {
+    if lhs.tokens != rhs.tokens {return false}
+    if lhs.tokenIds != rhs.tokenIds {return false}
+    if lhs.unknownFields != rhs.unknownFields {return false}
+    return true
+  }
+}
Sources/StableDiffusionServer/TokenizerService.swift (new file, 19 lines)
@@ -0,0 +1,19 @@
+import Foundation
+import GRPC
+import StableDiffusionCore
+import StableDiffusionProtos
+
+class TokenizerServiceProvider: SdTokenizerServiceAsyncProvider {
+  private let modelManager: ModelManager
+
+  init(modelManager: ModelManager) {
+    self.modelManager = modelManager
+  }
+
+  func tokenize(request: SdTokenizeRequest, context _: GRPCAsyncServerCallContext) async throws -> SdTokenizeResponse {
+    guard let state = await modelManager.getModelState(name: request.modelName) else {
+      throw SdCoreError.modelNotFound
+    }
+    return try await state.tokenize(request)
+  }
+}
@@ -34,7 +34,8 @@ struct ServerCommand: ParsableCommand {
     _ = Server.insecure(group: group)
       .withServiceProviders([
        ModelServiceProvider(modelManager: modelManager),
-       ImageGenerationServiceProvider(modelManager: modelManager)
+       ImageGenerationServiceProvider(modelManager: modelManager),
+       TokenizerServiceProvider(modelManager: modelManager)
       ])
       .bind(host: bindHost, port: bindPort)
 