-
Notifications
You must be signed in to change notification settings - Fork 4
/
helpers.ts
343 lines (291 loc) · 11.7 KB
/
helpers.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import * as Generation from "./generation/generation_pb";
import { GenerationServiceClient } from "./generation/generation_pb_service";
import { grpc as GRPCWeb } from "@improbable-eng/grpc-web";
import fs from "fs";
import { ArtifactTypeMap, FinishReasonMap } from "./generation/generation_pb";
/** A single weighted text prompt within a generation request. */
export type GenerationTextPrompt = {
  /** The text prompt, maximum of 2000 characters. */
  text: string;
  /** The weight of the prompt, use negative values for negative prompts. */
  weight?: number;
};

/** Parameters accepted by every generation request type. */
export type CommonGenerationParams = {
  /** One or more weighted text prompts guiding the generation. */
  prompts: GenerationTextPrompt[];
  /** Number of images to generate (service default: 1). */
  samples?: number;
  /** Number of diffusion inference steps (service default: 30). */
  steps?: number;
  /** CFG scale — how strongly generation is guided to match the prompt (service default: 7). */
  cfgScale?: number;
  /** Diffusion sampler; omit to let the service pick one automatically. */
  sampler?: Generation.DiffusionSamplerMap[keyof Generation.DiffusionSamplerMap];
  /** CLIP guidance preset; only effective with ancestral samplers. */
  clipGuidancePreset?: Generation.GuidancePresetMap[keyof Generation.GuidancePresetMap];
  /** Seed for deterministic results; omitting it or passing 0 yields a random result. */
  seed?: number;
};
/** Parameters for plain text-to-image generation. */
export type TextToImageParams = CommonGenerationParams & {
  type: "text-to-image";
  height?: number;
  width?: number;
};

/** Parameters for image-to-image generation seeded from an initial image. */
export type ImageToImageParams = CommonGenerationParams & {
  type: "image-to-image";
  /** Binary content of the initial image. */
  initImage: Buffer;
  /** How much influence the init image has on the diffusion process. */
  stepScheduleStart: number;
  stepScheduleEnd?: number;
};

/** Parameters for masked image-to-image generation. */
export type ImageToImageMaskingParams = CommonGenerationParams & {
  type: "image-to-image-masking";
  /** Binary content of the initial image. */
  initImage: Buffer;
  /** Mask image, sent as an ARTIFACT_MASK prompt alongside the init image. */
  maskImage: Buffer;
};

/** Parameters for upscaling an existing image. */
export type UpscalingParams = HeightOrWidth & {
  type: "upscaling";
  /** Binary content of the image to upscale. */
  initImage: Buffer;
  upscaler: Generation.UpscalerMap[keyof Generation.UpscalerMap];
};

/** Allows at most one of `height`/`width` to be supplied, or neither. */
type HeightOrWidth =
  | { height: number; width?: never }
  | { height?: never; width: number }
  | { height?: never; width?: never };

/** Discriminated union (on `type`) of every supported request parameter shape. */
export type GenerationRequestParams =
  | TextToImageParams
  | ImageToImageParams
  | ImageToImageMaskingParams
  | UpscalingParams;
/** Alias for the raw gRPC request message. */
export type GenerationRequest = Generation.Request;

/** Result of a generation: artifacts on success, or an `Error` value on failure. */
export type GenerationResponse = GenerationArtifacts | Error;

/** Artifacts extracted from a completed generation stream, split by outcome. */
export type GenerationArtifacts = {
  /**
   * Successfully generated artifacts whose binary content is available.
   */
  imageArtifacts: Array<ImageArtifact>;
  /**
   * These artifacts were filtered due to the NSFW classifier. This classifier is imperfect and
   * has frequent false-positives. You are not charged for blurred images and are welcome to retry.
   */
  filteredArtifacts: Array<NSFWFilteredArtifact>;
};
/**
 * A successfully generated image artifact: its type is ARTIFACT_IMAGE, its
 * finish reason is NULL (success), and its binary payload is present —
 * mirroring the runtime checks performed by `isImageArtifact`.
 *
 * Fix: the return types of `getType` and `getFinishReason` were swapped
 * (`getType` claimed a FinishReason and vice versa), contradicting the
 * runtime guard below.
 */
export type ImageArtifact = Omit<
  Generation.Artifact,
  "hasBinary" | "getType" | "getFinishReason"
> & {
  getType(): ArtifactTypeMap["ARTIFACT_IMAGE"];
  getFinishReason(): FinishReasonMap["NULL"];
  hasBinary(): true;
};
/** Type guard narrowing an artifact to a successfully generated image. */
export const isImageArtifact = (
  candidate: Generation.Artifact
): candidate is ImageArtifact => {
  // Guard clauses preserve the original short-circuit evaluation order.
  if (candidate.getType() !== Generation.ArtifactType.ARTIFACT_IMAGE) {
    return false;
  }
  if (candidate.getFinishReason() !== Generation.FinishReason.NULL) {
    return false;
  }
  return candidate.hasBinary();
};
/**
 * This represents an artifact whose content was blurred by the NSFW classifier:
 * its type is ARTIFACT_IMAGE and its finish reason is FILTER, mirroring the
 * runtime checks performed by `isNSFWFilteredArtifact`.
 *
 * Fix: the return types of `getType` and `getFinishReason` were swapped
 * (`getType` claimed a FinishReason and vice versa), contradicting the
 * runtime guard below.
 */
export type NSFWFilteredArtifact = Omit<
  Generation.Artifact,
  "getType" | "getFinishReason"
> & {
  getType(): ArtifactTypeMap["ARTIFACT_IMAGE"];
  getFinishReason(): FinishReasonMap["FILTER"];
};
/** Type guard narrowing an artifact to one blurred by the NSFW classifier. */
export const isNSFWFilteredArtifact = (
  candidate: Generation.Artifact
): candidate is NSFWFilteredArtifact => {
  // Non-image artifacts can never be NSFW-filtered images.
  if (candidate.getType() !== Generation.ArtifactType.ARTIFACT_IMAGE) {
    return false;
  }
  return candidate.getFinishReason() === Generation.FinishReason.FILTER;
};
/**
 * Builds a generation request for a specified engine with the specified parameters.
 *
 * @param engineID - ID of the engine to run the request against (see the list of
 *   known engines in the comment below).
 * @param params - Discriminated (on `type`) parameter set describing the request.
 * @returns A fully populated gRPC `Generation.Request`.
 */
export function buildGenerationRequest(
  engineID: string,
  params: GenerationRequestParams
): GenerationRequest {
  // Upscaling requests carry only the (optional) target dimension plus the
  // source image; none of the diffusion parameters below apply, so build the
  // request here and return early.
  if (params.type === "upscaling") {
    const request = new Generation.Request();
    request.setEngineId(engineID);
    request.setRequestedType(Generation.ArtifactType.ARTIFACT_IMAGE);
    request.setClassifier(new Generation.ClassifierParameters());

    const imageParams = new Generation.ImageParameters();
    // HeightOrWidth guarantees at most one dimension is provided.
    if ("width" in params && !!params.width) {
      imageParams.setWidth(params.width);
    } else if ("height" in params && !!params.height) {
      imageParams.setHeight(params.height);
    }
    request.setImage(imageParams);
    request.addPrompt(createInitImagePrompt(params.initImage));
    return request;
  }

  const imageParams = new Generation.ImageParameters();

  // Only text-to-image accepts explicit output dimensions; the other modes
  // derive dimensions from the init image.
  if (params.type === "text-to-image") {
    // NOTE(review): these truthiness guards skip explicit 0/undefined values,
    // leaving the service defaults in place.
    params.width && imageParams.setWidth(params.width);
    params.height && imageParams.setHeight(params.height);
  }

  // Set the number of images to generate (Default 1)
  params.samples && imageParams.setSamples(params.samples);

  // Set the steps (Default 30)
  // Represents the amount of inference steps performed on image generation.
  params.steps && imageParams.setSteps(params.steps);

  // Set the seed (Default 0)
  // Including a seed will cause the results to be deterministic.
  // Omitting the seed or setting it to `0` will do the opposite.
  params.seed && imageParams.addSeed(params.seed);

  // Set the sampler (Default 'automatic')
  // Omitting this value enables 'automatic' mode where we choose the best sampler for you based
  // on the current payload. For example, since CLIP guidance only works on ancestral samplers,
  // when CLIP guidance is enabled, we will automatically choose an ancestral sampler for you.
  if (params.sampler) {
    const transformType = new Generation.TransformType();
    transformType.setDiffusion(params.sampler);
    imageParams.setTransform(transformType);
  }

  // Set the Engine
  // At the time of writing, valid engines are:
  //  stable-diffusion-v1,
  //  stable-diffusion-v1-5
  //  stable-diffusion-512-v2-0
  //  stable-diffusion-768-v2-0
  //  stable-diffusion-512-v2-1
  //  stable-diffusion-768-v2-1
  //  stable-inpainting-v1-0
  //  stable-inpainting-512-v2-0
  //  stable-diffusion-xl-beta-v2-2-2
  //  stable-diffusion-xl-1024-v0-9
  //  stable-diffusion-xl-1024-v1-0
  //  esrgan-v1-x2plus
  const request = new Generation.Request();
  request.setEngineId(engineID);
  request.setRequestedType(Generation.ArtifactType.ARTIFACT_IMAGE);
  request.setClassifier(new Generation.ClassifierParameters());

  // Set the CFG scale (Default 7)
  // Influences how strongly your generation is guided to match your prompt. Higher values match closer.
  const samplerParams = new Generation.SamplerParameters();
  params.cfgScale && samplerParams.setCfgScale(params.cfgScale);

  const stepParams = new Generation.StepParameter();
  stepParams.setScaledStep(0);
  stepParams.setSampler(samplerParams);

  const scheduleParams = new Generation.ScheduleParameters();
  if (params.type === "image-to-image") {
    // If we're doing image-to-image generation then we need to configure
    // how much influence the initial image has on the diffusion process
    scheduleParams.setStart(params.stepScheduleStart);
    if (params.stepScheduleEnd) {
      scheduleParams.setEnd(params.stepScheduleEnd);
    }
  } else if (params.type === "image-to-image-masking") {
    // Step schedule start is always 1 for masking requests
    scheduleParams.setStart(1);
  }
  stepParams.setSchedule(scheduleParams);

  // Set CLIP Guidance (Default: None)
  // NOTE: This only works with ancestral samplers. Omitting the sampler parameter above will ensure
  // that we automatically choose an ancestral sampler for you when CLIP guidance is enabled.
  if (params.clipGuidancePreset) {
    const guidanceParameters = new Generation.GuidanceParameters();
    guidanceParameters.setGuidancePreset(params.clipGuidancePreset);
    stepParams.setGuidance(guidanceParameters);
  }

  imageParams.addParameters(stepParams);
  request.setImage(imageParams);

  // Attach each text prompt, with an optional weight parameter.
  params.prompts.forEach((textPrompt) => {
    const prompt = new Generation.Prompt();
    prompt.setText(textPrompt.text);

    // If provided, set the prompt's weight (use negative values for negative weighting)
    if (textPrompt.weight) {
      const promptParameters = new Generation.PromptParameters();
      promptParameters.setWeight(textPrompt.weight);
      prompt.setParameters(promptParameters);
    }
    request.addPrompt(prompt);
  });

  // Add image prompts if we're doing some kind of image-to-image generation or upscaling
  if (params.type === "image-to-image") {
    request.addPrompt(createInitImagePrompt(params.initImage));
  } else if (params.type === "image-to-image-masking") {
    request.addPrompt(createInitImagePrompt(params.initImage));
    request.addPrompt(createMaskImagePrompt(params.maskImage));
  }

  return request;
}
/** Wraps an init-image binary in a Prompt flagged with the `init` parameter. */
function createInitImagePrompt(imageBinary: Buffer): Generation.Prompt {
  // Mark this prompt as the initial image for the diffusion process.
  const parameters = new Generation.PromptParameters();
  parameters.setInit(true);

  const artifact = new Generation.Artifact();
  artifact.setBinary(imageBinary);
  artifact.setType(Generation.ArtifactType.ARTIFACT_IMAGE);

  const prompt = new Generation.Prompt();
  prompt.setParameters(parameters);
  prompt.setArtifact(artifact);
  return prompt;
}
/** Wraps a mask-image binary in a Prompt carrying an ARTIFACT_MASK artifact. */
function createMaskImagePrompt(imageBinary: Buffer): Generation.Prompt {
  const artifact = new Generation.Artifact();
  artifact.setType(Generation.ArtifactType.ARTIFACT_MASK);
  artifact.setBinary(imageBinary);

  const prompt = new Generation.Prompt();
  prompt.setArtifact(artifact);
  return prompt;
}
/**
 * Executes a GenerationRequest, abstracting the gRPC streaming result behind a Promise.
 *
 * Never rejects: any non-zero gRPC status or thrown exception is returned as an
 * `Error` value, per the `GenerationResponse` union.
 *
 * @param generationClient - gRPC-web client to issue the request with.
 * @param request - Request built via `buildGenerationRequest`.
 * @param metadata - gRPC metadata (e.g. the Authorization header).
 * @returns Extracted artifacts on success, or an `Error` describing the failure.
 */
export async function executeGenerationRequest(
  generationClient: GenerationServiceClient,
  request: GenerationRequest,
  metadata: GRPCWeb.Metadata
): Promise<GenerationResponse> {
  try {
    const stream = generationClient.generate(request, metadata);
    const answers = await new Promise<Generation.Answer[]>(
      (resolve, reject) => {
        const answers = new Array<Generation.Answer>();
        stream.on("data", (data) => answers.push(data));
        stream.on("end", () => resolve(answers));
        stream.on("status", (status) => {
          if (status.code === 0) return;
          // Fix: reject with a real Error rather than the bare details string,
          // so the catch below keeps a clean message instead of JSON-quoting it.
          reject(new Error(status.details));
        });
      }
    );
    return extractArtifacts(answers);
  } catch (err) {
    // Normalize non-Error throwables so callers can rely on `instanceof Error`.
    return err instanceof Error ? err : new Error(JSON.stringify(err));
  }
}
/** Partitions streamed answers into successful images and NSFW-filtered artifacts. */
function extractArtifacts(answers: Generation.Answer[]): GenerationArtifacts {
  const imageArtifacts: ImageArtifact[] = [];
  const filteredArtifacts: NSFWFilteredArtifact[] = [];

  // Flatten all artifacts across answers, then bucket each one by outcome.
  const allArtifacts = answers.flatMap((answer) => answer.getArtifactsList());
  for (const artifact of allArtifacts) {
    if (isImageArtifact(artifact)) {
      imageArtifacts.push(artifact);
    } else if (isNSFWFilteredArtifact(artifact)) {
      filteredArtifacts.push(artifact);
    }
  }
  return { filteredArtifacts, imageArtifacts };
}
/**
 * Generation completion handler - replace this with your own logic.
 *
 * Logs a summary of the results, then writes each successful image to
 * `image-<seed>.png` in the current working directory.
 *
 * @param response - Artifacts from a completed generation, or an Error.
 * @throws Rethrows `response` when generation failed.
 */
export function onGenerationComplete(response: GenerationResponse) {
  if (response instanceof Error) {
    console.error("Generation failed", response);
    throw response;
  }

  // Fix: correct singular/plural agreement (previously printed "1 image were"
  // and "0 image were").
  const imageCount = response.imageArtifacts.length;
  console.log(
    `${imageCount} ${
      imageCount === 1 ? "image was" : "images were"
    } successfully generated.`
  );

  // Do something with NSFW filtered artifacts
  const filteredCount = response.filteredArtifacts.length;
  if (filteredCount > 0) {
    const subject = filteredCount === 1 ? "artifact was" : "artifacts were";
    const verb = filteredCount === 1 ? "needs" : "need";
    console.log(
      `${filteredCount} ${subject} filtered by the NSFW classifier and ` +
        `${verb} to be retried.`
    );
  }

  // Do something with the successful image artifacts
  response.imageArtifacts.forEach((artifact: Generation.Artifact) => {
    try {
      fs.writeFileSync(
        `image-${artifact.getSeed()}.png`,
        Buffer.from(artifact.getBinary_asU8())
      );
    } catch (error) {
      console.error("Failed to write resulting image to disk", error);
    }
  });
  // For browser implementations: you could use the `artifact.getBinary_asB64()` method to get a
  // base64 encoded string and then create a data URL from that and display it in an <img> tag.
}