Skip to content

Commit

Permalink
feat: workflow executor (#7159)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed Jun 25, 2024
1 parent 45b3b83 commit fe89ecb
Show file tree
Hide file tree
Showing 16 changed files with 573 additions and 201 deletions.
2 changes: 2 additions & 0 deletions packages/backend/server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,14 @@
"on-headers": "^1.0.2",
"openai": "^4.33.0",
"parse-duration": "^1.1.0",
"piscina": "^4.5.1",
"pretty-time": "^1.1.0",
"prisma": "^5.12.1",
"prom-client": "^15.1.1",
"reflect-metadata": "^0.2.2",
"rxjs": "^7.8.1",
"semver": "^7.6.0",
"ses": "^1.4.1",
"socket.io": "^4.7.5",
"stripe": "^15.0.0",
"ts-node": "^10.9.2",
Expand Down
4 changes: 2 additions & 2 deletions packages/backend/server/src/data/migrations/utils/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ content: {{content}}`,
],
},
{
name: 'workflow:presentation:step4',
action: 'workflow:presentation:step4',
name: 'workflow:presentation:step5',
action: 'workflow:presentation:step5',
model: 'gpt-4o',
messages: [
{
Expand Down
3 changes: 2 additions & 1 deletion packages/backend/server/src/plugins/copilot/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
} from './resolver';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotWorkflowService } from './workflow';
import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';

registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
Expand All @@ -41,6 +41,7 @@ registerCopilotProvider(OpenAIProvider);
CopilotStorage,
PromptsManagementResolver,
CopilotWorkflowService,
...CopilotWorkflowExecutors,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import { Injectable } from '@nestjs/common';

import { ChatPrompt, PromptService } from '../../prompt';
import { CopilotProviderService } from '../../providers';
import { CopilotChatOptions, CopilotTextProvider } from '../../types';
import {
NodeData,
WorkflowNodeType,
WorkflowResult,
WorkflowResultType,
} from '../types';
import { WorkflowExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';

@Injectable()
export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
constructor(
private readonly promptService: PromptService,
private readonly providerService: CopilotProviderService
) {
super();
}

private async initExecutor(
data: NodeData
): Promise<
[
NodeData & { nodeType: WorkflowNodeType.Basic },
ChatPrompt,
CopilotTextProvider,
]
> {
if (data.nodeType !== WorkflowNodeType.Basic) {
throw new Error(
`Executor ${this.type} not support ${data.nodeType} node`
);
}

const prompt = await this.promptService.get(data.promptName);
if (!prompt) {
throw new Error(
`Prompt ${data.promptName} not found when running workflow node ${data.name}`
);
}
const provider = await this.providerService.getProviderByModel(
prompt.model
);
if (provider && 'generateText' in provider) {
return [data, prompt, provider];
}

throw new Error(
`Provider not found for model ${prompt.model} when running workflow node ${data.name}`
);
}

override get type() {
return WorkflowExecutorType.ChatText;
}

override async *next(
data: NodeData,
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);

const finalMessage = prompt.finish(params);
if (paramKey) {
// update params with custom key
yield {
type: WorkflowResultType.Params,
params: {
[paramKey]: await provider.generateText(
finalMessage,
prompt.model,
options
),
},
};
} else {
for await (const content of provider.generateTextStream(
finalMessage,
prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: id,
content,
};
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import { CopilotChatTextExecutor } from './chat-text';

export const CopilotWorkflowExecutors = [CopilotChatTextExecutor];

export { type WorkflowExecutor, WorkflowExecutorType } from './types';
export { getWorkflowExecutor } from './utils';
export { CopilotChatTextExecutor };
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { CopilotChatOptions } from '../../types';
import { NodeData, WorkflowResult } from '../types';

export enum WorkflowExecutorType {
ChatText = 'ChatText',
}

export abstract class WorkflowExecutor {
abstract get type(): WorkflowExecutorType;
abstract next(
data: NodeData,
params: Record<string, string | string[]>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult>;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { Logger, OnModuleInit } from '@nestjs/common';

import { WorkflowExecutor, type WorkflowExecutorType } from './types';

const WORKFLOW_EXECUTOR: Map<string, WorkflowExecutor> = new Map();

function registerWorkflowExecutor(e: WorkflowExecutor) {
const existing = WORKFLOW_EXECUTOR.get(e.type);
if (existing && existing === e) return false;
WORKFLOW_EXECUTOR.set(e.type, e);
return true;
}

export function getWorkflowExecutor(
type: WorkflowExecutorType
): WorkflowExecutor {
const executor = WORKFLOW_EXECUTOR.get(type);
if (!executor) {
throw new Error(`Executor ${type} not defined`);
}

return executor;
}

export abstract class AutoRegisteredWorkflowExecutor
extends WorkflowExecutor
implements OnModuleInit
{
onModuleInit() {
this.register();
}

register() {
if (registerWorkflowExecutor(this)) {
new Logger(`CopilotWorkflowExecutor:${this.type}`).log(
'Workflow executor registered.'
);
}
}
}
80 changes: 37 additions & 43 deletions packages/backend/server/src/plugins/copilot/workflow/graph.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import { type WorkflowGraphList, WorkflowNodeType } from './types';
import { WorkflowExecutorType } from './executor';
import type { WorkflowGraphs } from './types';
import { WorkflowNodeState, WorkflowNodeType } from './types';

export const WorkflowGraphs: WorkflowGraphList = [
export const WorkflowGraphList: WorkflowGraphs = [
{
name: 'presentation',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: 'text',
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
Expand All @@ -17,49 +19,41 @@ export const WorkflowGraphs: WorkflowGraphList = [
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: 'text',
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step2',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step 3: check format',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step3',
paramKey: 'needFormat',
edges: ['step4'],
},
{
id: 'step4',
name: 'Step 4: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) =>
nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')],
edges: ['step5', 'step6'],
},
{
id: 'step5',
name: 'Step 5: format presentation',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step5',
edges: ['step6'],
},
{
id: 'step6',
name: 'Step 6: finish',
nodeType: WorkflowNodeType.Nope,
edges: [],
// edges: ['step3'],
},
// {
// id: 'step3',
// name: 'Step 3: check format',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'workflow:presentation:step3',
// paramKey: 'needFormat',
// edges: ['step4'],
// },
// {
// id: 'step4',
// name: 'Step 4: format presentation if needed',
// nodeType: WorkflowNodeType.Decision,
// condition: ((
// nodeIds: string[],
// params: WorkflowNodeState
// ) =>
// nodeIds[
// Number(String(params.needFormat).toLowerCase() === 'true')
// ]).toString(),
// edges: ['step5', 'step6'],
// },
// {
// id: 'step5',
// name: 'Step 5: format presentation',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'workflow:presentation:step5',
// edges: ['step6'],
// },
// {
// id: 'step6',
// name: 'Step 6: finish',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'workflow:presentation:step6',
// edges: [],
// },
],
},
];
38 changes: 17 additions & 21 deletions packages/backend/server/src/plugins/copilot/workflow/index.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
import { Injectable, Logger } from '@nestjs/common';

import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphs } from './graph';
import { WorkflowGraphList } from './graph';
import { WorkflowNode } from './node';
import { WorkflowGraph, WorkflowGraphList } from './types';
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
import { CopilotWorkflow } from './workflow';

@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
constructor(
private readonly prompt: PromptService,
private readonly provider: CopilotProviderService
) {}
constructor() {}

private initWorkflow({ name, graph }: WorkflowGraphList[number]) {
const workflow = new Map();
for (const nodeData of graph) {
private initWorkflow(graph: WorkflowGraph) {
const workflow = new Map<string, WorkflowNode>();
for (const nodeData of graph.graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(data);
const node = new WorkflowNode(graph, data);
workflow.set(node.id, node);
}

// add edges
for (const nodeData of graph) {
for (const nodeData of graph.graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
Expand All @@ -47,9 +42,11 @@ export class CopilotWorkflowService {
return workflow;
}

// TODO(@darksky): get workflow from database
private async getWorkflow(graphName: string): Promise<WorkflowGraph> {
const graph = WorkflowGraphs.find(g => g.name === graphName);
// TODO(@darkskygit): get workflow from database
private async getWorkflow(
graphName: string
): Promise<WorkflowGraphInstances> {
const graph = WorkflowGraphList.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
Expand All @@ -63,14 +60,13 @@ export class CopilotWorkflowService {
options?: CopilotChatOptions
): AsyncIterable<string> {
const workflowGraph = await this.getWorkflow(graphName);
const workflow = new CopilotWorkflow(
this.prompt,
this.provider,
workflowGraph
);
const workflow = new CopilotWorkflow(workflowGraph);

for await (const result of workflow.runGraph(params, options)) {
yield result;
}
}
}

export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor';
export { WorkflowNodeType } from './types';
Loading

0 comments on commit fe89ecb

Please sign in to comment.