Skip to content

Commit

Permalink
dynamically determine available checkpoints, some workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnrushefsky committed Aug 16, 2024
1 parent fe3867f commit 0cfa238
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 12 deletions.
15 changes: 14 additions & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const {
STARTUP_CHECK_MAX_TRIES = "10",
OUTPUT_DIR = "/opt/ComfyUI/output",
INPUT_DIR = "/opt/ComfyUI/input",
CKPT_DIR = "/opt/ComfyUI/models/checkpoints",
WARMUP_PROMPT_FILE,
} = process.env;

Expand All @@ -18,18 +19,28 @@ const port = parseInt(PORT, 10);
const startupCheckInterval = parseInt(STARTUP_CHECK_INTERVAL_S, 10) * 1000;
const startupCheckMaxTries = parseInt(STARTUP_CHECK_MAX_TRIES, 10);

let warmupPrompt: string | undefined;
let warmupPrompt: any | undefined;
let warmupCkpt: string | undefined;
if (WARMUP_PROMPT_FILE) {
assert(fs.existsSync(WARMUP_PROMPT_FILE), "Warmup prompt file not found");
try {
warmupPrompt = JSON.parse(
fs.readFileSync(WARMUP_PROMPT_FILE, { encoding: "utf-8" })
);
for (const nodeId in warmupPrompt) {
const node = warmupPrompt[nodeId];
if (node.class_type === "CheckpointLoaderSimple") {
warmupCkpt = node.inputs.ckpt_name;
break;
}
}
} catch (e: any) {
throw new Error(`Failed to parse warmup prompt: ${e.message}`);
}
}

const allCheckpoints = fs.readdirSync(CKPT_DIR);

const config = {
comfyLaunchCmd: CMD,
wrapperHost: HOST,
Expand All @@ -42,6 +53,8 @@ const config = {
outputDir: OUTPUT_DIR,
inputDir: INPUT_DIR,
warmupPrompt,
warmupCkpt,
checkpoints: allCheckpoints,
};

export default config;
9 changes: 8 additions & 1 deletion src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,14 @@ server.after(() => {
body: JSON.stringify({ prompt, id, webhook }),
}
);
return reply.code(resp.status).send(await resp.json());
const body = await resp.json();
if (!resp.ok) {
return reply.code(resp.status).send(body);
}

body.prompt = prompt;

return reply.code(resp.status).send(body);
}
);
}
Expand Down
11 changes: 8 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { z } from "zod";
import { string, z } from "zod";
import { randomUUID } from "crypto";
import config from "./config";

export const ComfyNodeSchema = z.object({
inputs: z.any(),
Expand Down Expand Up @@ -42,10 +43,10 @@ export const WorkflowSchema = z.object({
generateWorkflow: z.function(),
});

export type Workflow = {
export interface Workflow {
RequestSchema: z.ZodObject<any, any>;
generateWorkflow: (input: any) => Record<string, ComfyNode>;
};
}

export const WorkflowRequestSchema = z.object({
id: z
Expand All @@ -57,3 +58,7 @@ export const WorkflowRequestSchema = z.object({
});

export type WorkflowRequest = z.infer<typeof WorkflowRequestSchema>;

export const AvailableCheckpoints = z.enum(
config.checkpoints as unknown as readonly [string, ...string[]]
);
153 changes: 153 additions & 0 deletions src/workflows/flux/img2img.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import { z } from "zod";
import { ComfyNode, Workflow, AvailableCheckpoints } from "../../types";
import config from "../../config";

let checkpoint: any = AvailableCheckpoints.optional();
if (config.warmupCkpt) {
checkpoint = AvailableCheckpoints.default(config.warmupCkpt);
}

const RequestSchema = z.object({
prompt: z.string(),
width: z.number().int().min(256).max(1024).optional().default(1024),
height: z.number().int().min(256).max(1024).optional().default(1024),
seed: z
.number()
.int()
.optional()
.default(() => Math.floor(Math.random() * 1000000000000000)),
steps: z.number().int().min(1).max(10).optional().default(2),
sampler: z.enum(["euler"]).optional().default("euler"),
scheduler: z.enum(["simple"]).optional().default("simple"),
denoise: z.number().min(0).max(1).optional().default(0.8),
cfg: z.number().min(1).max(30).optional().default(1),
image: z.string(),
checkpoint,
});

type InputType = z.infer<typeof RequestSchema>;

function generateWorkflow(input: InputType): Record<string, ComfyNode> {
return {
"6": {
inputs: {
text: input.prompt,
clip: ["30", 1],
},
class_type: "CLIPTextEncode",
_meta: {
title: "CLIP Text Encode (Positive Prompt)",
},
},
"8": {
inputs: {
samples: ["31", 0],
vae: ["30", 2],
},
class_type: "VAEDecode",
_meta: {
title: "VAE Decode",
},
},
"9": {
inputs: {
filename_prefix: "",
images: ["8", 0],
},
class_type: "SaveImage",
_meta: {
title: "Save Image",
},
},
"27": {
inputs: {
width: input.width,
height: input.height,
batch_size: 1,
},
class_type: "EmptySD3LatentImage",
_meta: {
title: "EmptySD3LatentImage",
},
},
"30": {
inputs: {
ckpt_name: input.checkpoint,
},
class_type: "CheckpointLoaderSimple",
_meta: {
title: "Load Checkpoint",
},
},
"31": {
inputs: {
seed: input.seed,
steps: input.steps,
cfg: input.cfg,
sampler_name: input.sampler,
scheduler: input.scheduler,
denoise: input.denoise,
model: ["30", 0],
positive: ["6", 0],
negative: ["33", 0],
latent_image: ["38", 0],
},
class_type: "KSampler",
_meta: {
title: "KSampler",
},
},
"33": {
inputs: {
text: "",
clip: ["30", 1],
},
class_type: "CLIPTextEncode",
_meta: {
title: "CLIP Text Encode (Negative Prompt)",
},
},
"37": {
inputs: {
image: input.image,
upload: "image",
},
class_type: "LoadImage",
_meta: {
title: "Load Image",
},
},
"38": {
inputs: {
pixels: ["40", 0],
vae: ["30", 2],
},
class_type: "VAEEncode",
_meta: {
title: "VAE Encode",
},
},
"40": {
inputs: {
width: input.width,
height: input.height,
interpolation: "nearest",
method: "fill / crop",
condition: "always",
multiple_of: 8,
image: ["37", 0],
},
class_type: "ImageResize+",
_meta: {
title: "🔧 Image Resize",
},
},
};
}

const workflow: Workflow = {
RequestSchema,
generateWorkflow,
};

export default workflow;
17 changes: 10 additions & 7 deletions src/workflows/flux/txt2img.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import { z } from "zod";
import { ComfyNode, Workflow } from "../../types";
import { ComfyNode, Workflow, AvailableCheckpoints } from "../../types";
import config from "../../config";

let checkpoint: any = AvailableCheckpoints.optional();
if (config.warmupCkpt) {
checkpoint = AvailableCheckpoints.default(config.warmupCkpt);
}

const RequestSchema = z.object({
prompt: z.string(),
Expand All @@ -13,15 +19,12 @@ const RequestSchema = z.object({
steps: z.number().int().min(1).max(10).optional().default(4),
sampler: z.enum(["euler"]).optional().default("euler"), // This may need to be expanded with more options
scheduler: z.enum(["simple"]).optional().default("simple"), // This may need to be expanded with more options
checkpoint: z
.enum(["flux1-schnell-fp8.safetensors"])
.optional()
.default("flux1-schnell-fp8.safetensors"), // This may need to be expanded with more options
checkpoint,
});

type Input = z.infer<typeof RequestSchema>;
type InputType = z.infer<typeof RequestSchema>;

export function generateWorkflow(input: Input): Record<string, ComfyNode> {
function generateWorkflow(input: InputType): Record<string, ComfyNode> {
return {
"6": {
inputs: {
Expand Down
6 changes: 6 additions & 0 deletions src/workflows/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import fluxTxt2img from "../workflows/flux/txt2img";
import fluxImg2img from "../workflows/flux/img2img";
import sd15Txt2img from "../workflows/sd1.5/txt2img";

export const workflows: any = {
flux: {
txt2img: fluxTxt2img,
img2img: fluxImg2img,
},
"sd1.5": {
txt2img: sd15Txt2img,
},
};
Loading

0 comments on commit 0cfa238

Please sign in to comment.