diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index c84534bbc04ad..8758f6af80a94 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -80,12 +80,14 @@ "on-headers": "^1.0.2", "openai": "^4.33.0", "parse-duration": "^1.1.0", + "piscina": "^4.5.1", "pretty-time": "^1.1.0", "prisma": "^5.12.1", "prom-client": "^15.1.1", "reflect-metadata": "^0.2.2", "rxjs": "^7.8.1", "semver": "^7.6.0", + "ses": "^1.4.1", "socket.io": "^4.7.5", "stripe": "^15.0.0", "ts-node": "^10.9.2", diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index c4bdd7c866b63..83f3da247d81c 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -514,8 +514,8 @@ content: {{content}}`, ], }, { - name: 'workflow:presentation:step4', - action: 'workflow:presentation:step4', + name: 'workflow:presentation:step5', + action: 'workflow:presentation:step5', model: 'gpt-4o', messages: [ { diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index ee8a8490010b4..79d0b8cb52665 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -22,7 +22,7 @@ import { } from './resolver'; import { ChatSessionService } from './session'; import { CopilotStorage } from './storage'; -import { CopilotWorkflowService } from './workflow'; +import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow'; registerCopilotProvider(FalProvider); registerCopilotProvider(OpenAIProvider); @@ -41,6 +41,7 @@ registerCopilotProvider(OpenAIProvider); CopilotStorage, PromptsManagementResolver, CopilotWorkflowService, + ...CopilotWorkflowExecutors, ], controllers: [CopilotController], contributesTo: ServerFeature.Copilot, diff --git a/packages/backend/server/src/plugins/copilot/workflow/executor/chat-text.ts b/packages/backend/server/src/plugins/copilot/workflow/executor/chat-text.ts new file mode 100644 index 0000000000000..1776b69d36fae --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/workflow/executor/chat-text.ts @@ -0,0 +1,95 @@ +import { Injectable } from '@nestjs/common'; + +import { ChatPrompt, PromptService } from '../../prompt'; +import { CopilotProviderService } from '../../providers'; +import { CopilotChatOptions, CopilotTextProvider } from '../../types'; +import { + NodeData, + WorkflowNodeType, + WorkflowResult, + WorkflowResultType, +} from '../types'; +import { WorkflowExecutorType } from './types'; +import { AutoRegisteredWorkflowExecutor } from './utils'; + +@Injectable() +export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor { + constructor( + private readonly promptService: PromptService, + private readonly providerService: CopilotProviderService + ) { + super(); + } + + private async initExecutor( + data: NodeData + ): Promise< + [ + NodeData & { nodeType: WorkflowNodeType.Basic }, + ChatPrompt, + CopilotTextProvider, + ] + > { + if (data.nodeType !== WorkflowNodeType.Basic) { + throw new Error( + `Executor ${this.type} not support ${data.nodeType} node` + ); + } + + const prompt = await this.promptService.get(data.promptName); + if (!prompt) { + throw new Error( + `Prompt ${data.promptName} not found when running workflow node ${data.name}` + ); + } + const provider = await this.providerService.getProviderByModel( + prompt.model + ); + if (provider && 'generateText' in provider) { + return [data, prompt, provider]; + } + + throw new Error( + `Provider not found for model ${prompt.model} when running workflow node ${data.name}` + ); + } + + override get type() { + return WorkflowExecutorType.ChatText; + } + + override async *next( + data: NodeData, + params: Record, + options?: CopilotChatOptions + ): AsyncIterable { + const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data); + + const finalMessage = prompt.finish(params); + if (paramKey) { + // update params with custom key + yield { + type: WorkflowResultType.Params, + params: { + [paramKey]: await provider.generateText( + finalMessage, + prompt.model, + options + ), + }, + }; + } else { + for await (const content of provider.generateTextStream( + finalMessage, + prompt.model, + options + )) { + yield { + type: WorkflowResultType.Content, + nodeId: id, + content, + }; + } + } + } +} diff --git a/packages/backend/server/src/plugins/copilot/workflow/executor/index.ts b/packages/backend/server/src/plugins/copilot/workflow/executor/index.ts new file mode 100644 index 0000000000000..8f10845e34f8b --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/workflow/executor/index.ts @@ -0,0 +1,7 @@ +import { CopilotChatTextExecutor } from './chat-text'; + +export const CopilotWorkflowExecutors = [CopilotChatTextExecutor]; + +export { type WorkflowExecutor, WorkflowExecutorType } from './types'; +export { getWorkflowExecutor } from './utils'; +export { CopilotChatTextExecutor }; diff --git a/packages/backend/server/src/plugins/copilot/workflow/executor/types.ts b/packages/backend/server/src/plugins/copilot/workflow/executor/types.ts new file mode 100644 index 0000000000000..bf4f1ddd3aa01 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/workflow/executor/types.ts @@ -0,0 +1,15 @@ +import { CopilotChatOptions } from '../../types'; +import { NodeData, WorkflowResult } from '../types'; + +export enum WorkflowExecutorType { + ChatText = 'ChatText', +} + +export abstract class WorkflowExecutor { + abstract get type(): WorkflowExecutorType; + abstract next( + data: NodeData, + params: Record, + options?: CopilotChatOptions + ): AsyncIterable; +} diff --git a/packages/backend/server/src/plugins/copilot/workflow/executor/utils.ts b/packages/backend/server/src/plugins/copilot/workflow/executor/utils.ts new file mode 100644 index 0000000000000..464a405454edf --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/workflow/executor/utils.ts @@ -0,0 +1,40 @@ +import { Logger, OnModuleInit } from '@nestjs/common'; + +import { WorkflowExecutor, type WorkflowExecutorType } from './types'; + +const WORKFLOW_EXECUTOR: Map = new Map(); + +function registerWorkflowExecutor(e: WorkflowExecutor) { + const existing = WORKFLOW_EXECUTOR.get(e.type); + if (existing && existing === e) return false; + WORKFLOW_EXECUTOR.set(e.type, e); + return true; +} + +export function getWorkflowExecutor( + type: WorkflowExecutorType +): WorkflowExecutor { + const executor = WORKFLOW_EXECUTOR.get(type); + if (!executor) { + throw new Error(`Executor ${type} not defined`); + } + + return executor; +} + +export abstract class AutoRegisteredWorkflowExecutor + extends WorkflowExecutor + implements OnModuleInit +{ + onModuleInit() { + this.register(); + } + + register() { + if (registerWorkflowExecutor(this)) { + new Logger(`CopilotWorkflowExecutor:${this.type}`).log( + 'Workflow executor registered.' + ); + } + } +} diff --git a/packages/backend/server/src/plugins/copilot/workflow/graph.ts b/packages/backend/server/src/plugins/copilot/workflow/graph.ts index 2b995b03e1f13..d83da49532f04 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/graph.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/graph.ts @@ -1,6 +1,8 @@ -import { type WorkflowGraphList, WorkflowNodeType } from './types'; +import { WorkflowExecutorType } from './executor'; +import type { WorkflowGraphs } from './types'; +import { WorkflowNodeState, WorkflowNodeType } from './types'; -export const WorkflowGraphs: WorkflowGraphList = [ +export const WorkflowGraphList: WorkflowGraphs = [ { name: 'presentation', graph: [ @@ -8,7 +10,7 @@ export const WorkflowGraphs: WorkflowGraphList = [ id: 'start', name: 'Start: check language', nodeType: WorkflowNodeType.Basic, - type: 'text', + type: WorkflowExecutorType.ChatText, promptName: 'workflow:presentation:step1', paramKey: 'language', edges: ['step2'], @@ -17,49 +19,41 @@ export const WorkflowGraphs: WorkflowGraphList = [ id: 'step2', name: 'Step 2: generate presentation', nodeType: WorkflowNodeType.Basic, - type: 'text', + type: WorkflowExecutorType.ChatText, promptName: 'workflow:presentation:step2', + edges: ['step3'], + }, + { + id: 'step3', + name: 'Step 3: check format', + nodeType: WorkflowNodeType.Basic, + type: WorkflowExecutorType.ChatText, + promptName: 'workflow:presentation:step3', + paramKey: 'needFormat', + edges: ['step4'], + }, + { + id: 'step4', + name: 'Step 4: format presentation if needed', + nodeType: WorkflowNodeType.Decision, + condition: (nodeIds: string[], params: WorkflowNodeState) => + nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')], + edges: ['step5', 'step6'], + }, + { + id: 'step5', + name: 'Step 5: format presentation', + nodeType: WorkflowNodeType.Basic, + type: WorkflowExecutorType.ChatText, + promptName: 'workflow:presentation:step5', + edges: ['step6'], + }, + { + id: 'step6', + name: 'Step 6: finish', + nodeType: WorkflowNodeType.Nope, edges: [], - // edges: ['step3'], }, - // { - // id: 'step3', - // name: 'Step 3: check format', - // nodeType: WorkflowNodeType.Basic, - // type: 'text', - // promptName: 'workflow:presentation:step3', - // paramKey: 'needFormat', - // edges: ['step4'], - // }, - // { - // id: 'step4', - // name: 'Step 4: format presentation if needed', - // nodeType: WorkflowNodeType.Decision, - // condition: (( - // nodeIds: string[], - // params: WorkflowNodeState - // ) => - // nodeIds[ - // Number(String(params.needFormat).toLowerCase() === 'true') - // ]).toString(), - // edges: ['step5', 'step6'], - // }, - // { - // id: 'step5', - // name: 'Step 5: format presentation', - // nodeType: WorkflowNodeType.Basic, - // type: 'text', - // promptName: 'workflow:presentation:step5', - // edges: ['step6'], - // }, - // { - // id: 'step6', - // name: 'Step 6: finish', - // nodeType: WorkflowNodeType.Basic, - // type: 'text', - // promptName: 'workflow:presentation:step6', - // edges: [], - // }, ], }, ]; diff --git a/packages/backend/server/src/plugins/copilot/workflow/index.ts b/packages/backend/server/src/plugins/copilot/workflow/index.ts index a3c491073cb05..8feefbfe4dc32 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/index.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/index.ts @@ -1,31 +1,26 @@ import { Injectable, Logger } from '@nestjs/common'; -import { PromptService } from '../prompt'; -import { CopilotProviderService } from '../providers'; import { CopilotChatOptions } from '../types'; -import { WorkflowGraphs } from './graph'; +import { WorkflowGraphList } from './graph'; import { WorkflowNode } from './node'; -import { WorkflowGraph, WorkflowGraphList } from './types'; +import type { WorkflowGraph, WorkflowGraphInstances } from './types'; import { CopilotWorkflow } from './workflow'; @Injectable() export class CopilotWorkflowService { private readonly logger = new Logger(CopilotWorkflowService.name); - constructor( - private readonly prompt: PromptService, - private readonly provider: CopilotProviderService - ) {} + constructor() {} - private initWorkflow({ name, graph }: WorkflowGraphList[number]) { - const workflow = new Map(); - for (const nodeData of graph) { + private initWorkflow(graph: WorkflowGraph) { + const workflow = new Map(); + for (const nodeData of graph.graph) { const { edges: _, ...data } = nodeData; - const node = new WorkflowNode(data); + const node = new WorkflowNode(graph, data); workflow.set(node.id, node); } // add edges - for (const nodeData of graph) { + for (const nodeData of graph.graph) { const node = workflow.get(nodeData.id); if (!node) { this.logger.error( @@ -47,9 +42,11 @@ export class CopilotWorkflowService { return workflow; } - // TODO(@darksky): get workflow from database - private async getWorkflow(graphName: string): Promise { - const graph = WorkflowGraphs.find(g => g.name === graphName); + // TODO(@darkskygit): get workflow from database + private async getWorkflow( + graphName: string + ): Promise { + const graph = WorkflowGraphList.find(g => g.name === graphName); if (!graph) { throw new Error(`Graph ${graphName} not found`); } @@ -63,14 +60,13 @@ export class CopilotWorkflowService { options?: CopilotChatOptions ): AsyncIterable { const workflowGraph = await this.getWorkflow(graphName); - const workflow = new CopilotWorkflow( - this.prompt, - this.provider, - workflowGraph - ); + const workflow = new CopilotWorkflow(workflowGraph); for await (const result of workflow.runGraph(params, options)) { yield result; } } } + +export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor'; +export { WorkflowNodeType } from './types'; diff --git a/packages/backend/server/src/plugins/copilot/workflow/node.ts b/packages/backend/server/src/plugins/copilot/workflow/node.ts index 4d7c5f76191db..775a9f3d54012 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/node.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/node.ts @@ -1,21 +1,73 @@ -import { ChatPrompt, PromptService } from '../prompt'; -import { CopilotProviderService } from '../providers'; -import { CopilotAllProvider, CopilotChatOptions } from '../types'; -import { +import path, { dirname } from 'node:path'; +import { fileURLToPath } from 'node:url'; + +import { Logger } from '@nestjs/common'; +import Piscina from 'piscina'; + +import { CopilotChatOptions } from '../types'; +import { getWorkflowExecutor, WorkflowExecutor } from './executor'; +import type { NodeData, + WorkflowGraph, WorkflowNodeState, - WorkflowNodeType, WorkflowResult, - WorkflowResultType, } from './types'; +import { WorkflowNodeType, WorkflowResultType } from './types'; export class WorkflowNode { + private readonly logger = new Logger(WorkflowNode.name); private readonly edges: WorkflowNode[] = []; private readonly parents: WorkflowNode[] = []; - private prompt: ChatPrompt | null = null; - private provider: CopilotAllProvider | null = null; - - constructor(private readonly data: NodeData) {} + private readonly executor: WorkflowExecutor | null = null; + private readonly condition: + | ((params: WorkflowNodeState) => Promise) + | null = null; + + constructor( + graph: WorkflowGraph, + private readonly data: NodeData + ) { + if (data.nodeType === WorkflowNodeType.Basic) { + this.executor = getWorkflowExecutor(data.type); + } else if (data.nodeType === WorkflowNodeType.Decision) { + // prepare decision condition, reused in each run + const iife = `(${data.condition})(nodeIds, params)`; + // only eval the condition in worker if graph has been modified + if (graph.modified) { + const worker = new Piscina({ + filename: path.resolve( + dirname(fileURLToPath(import.meta.url)), + 'worker.mjs' + ), + minThreads: 2, + // empty envs from parent process + env: {}, + argv: [], + execArgv: [], + }); + this.condition = (params: WorkflowNodeState) => + worker.run({ + iife, + nodeIds: this.edges.map(node => node.id), + params, + }); + } else { + const func = + typeof data.condition === 'function' + ? data.condition + : new Function( + 'nodeIds', + 'params', + `(${data.condition})(nodeIds, params)` + ); + this.condition = (params: WorkflowNodeState) => + func( + this.edges.map(node => node.id), + params + ); + } + } + } get id(): string { return this.data.id; @@ -33,6 +85,11 @@ export class WorkflowNode { return this.parents; } + // if is the end of the workflow, pass through the content to stream response + get hasEdges(): boolean { + return !!this.edges.length; + } + private set parent(node: WorkflowNode) { if (!this.parents.includes(node)) { this.parents.push(node); @@ -44,7 +101,10 @@ export class WorkflowNode { if (this.edges.length > 0) { throw new Error(`Basic block can only have one edge`); } - } else if (!this.data.condition) { + } else if ( + this.data.nodeType === WorkflowNodeType.Decision && + !this.data.condition + ) { throw new Error(`Decision block must have a condition`); } node.parent = this; @@ -52,84 +112,34 @@ export class WorkflowNode { return this.edges.length; } - async initNode(prompt: PromptService, provider: CopilotProviderService) { - if (this.prompt && this.provider) return; - - if (this.data.nodeType === WorkflowNodeType.Basic) { - this.prompt = await prompt.get(this.data.promptName); - if (!this.prompt) { - throw new Error( - `Prompt ${this.data.promptName} not found when running workflow node ${this.name}` - ); - } - this.provider = await provider.getProviderByModel(this.prompt.model); - if (!this.provider) { - throw new Error( - `Provider not found for model ${this.prompt.model} when running workflow node ${this.name}` - ); - } - } - } - private async evaluateCondition( - _condition?: string + params: WorkflowNodeState ): Promise { - // TODO(@darksky): evaluate condition to impl decision block - return this.edges[0]?.id; - } - - private getStreamProvider() { - if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) { - if ( - this.data.type === 'text' && - 'generateText' in this.provider && - !this.data.paramKey - ) { - return this.provider.generateTextStream.bind(this.provider); - } else if ( - this.data.type === 'image' && - 'generateImages' in this.provider && - !this.data.paramKey - ) { - return this.provider.generateImagesStream.bind(this.provider); - } - } - throw new Error(`Stream Provider not found for node ${this.name}`); - } - - private getProvider() { - if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) { - if ( - this.data.type === 'text' && - 'generateText' in this.provider && - this.data.paramKey - ) { - return this.provider.generateText.bind(this.provider); - } else if ( - this.data.type === 'image' && - 'generateImages' in this.provider && - this.data.paramKey - ) { - return this.provider.generateImages.bind(this.provider); - } + // early return if no edges + if (this.edges.length === 0) return undefined; + try { + const result = await this.condition?.(params); + if (typeof result === 'string') return result; + // choose default edge if condition falsy + return this.edges[0].id; + } catch (e) { + this.logger.error( + `Failed to evaluate condition for node ${this.name}: ${e}` + ); + throw e; } - throw new Error(`Provider not found for node ${this.name}`); } async *next( params: WorkflowNodeState, options?: CopilotChatOptions ): AsyncIterable { - if (!this.prompt || !this.provider) { - throw new Error(`Node ${this.name} not initialized`); - } - yield { type: WorkflowResultType.StartRun, nodeId: this.id }; // choose next node in graph let nextNode: WorkflowNode | undefined = this.edges[0]; if (this.data.nodeType === WorkflowNodeType.Decision) { - const nextNodeId = await this.evaluateCondition(this.data.condition); + const nextNodeId = await this.evaluateCondition(params); // return empty to choose default edge if (nextNodeId) { nextNode = this.edges.find(node => node.id === nextNodeId); @@ -137,37 +147,18 @@ export class WorkflowNode { throw new Error(`No edge found for condition ${this.data.condition}`); } } - } else { - const finalMessage = this.prompt.finish(params); - if (this.data.paramKey) { - const provider = this.getProvider(); - // update params with custom key - yield { - type: WorkflowResultType.Params, - params: { - [this.data.paramKey]: await provider( - finalMessage, - this.prompt.model, - options - ), - }, - }; - } else { - const provider = this.getStreamProvider(); - for await (const content of provider( - finalMessage, - this.prompt.model, - options - )) { - yield { - type: WorkflowResultType.Content, - nodeId: this.id, - content, - // pass through content as a stream response if no next node - passthrough: !nextNode, - }; - } + } else if (this.data.nodeType === WorkflowNodeType.Basic) { + if (!this.executor) { + throw new Error(`Node ${this.name} not initialized`); } + + yield* this.executor.next(this.data, params, options); + } else { + yield { + type: WorkflowResultType.Content, + nodeId: this.id, + content: params.content, + }; } yield { type: WorkflowResultType.EndRun, nextNode }; diff --git a/packages/backend/server/src/plugins/copilot/workflow/types.ts b/packages/backend/server/src/plugins/copilot/workflow/types.ts index 8fe91434b3d24..509b814074ae5 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/types.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/types.ts @@ -1,28 +1,40 @@ +import type { WorkflowExecutorType } from './executor'; import type { WorkflowNode } from './node'; export enum WorkflowNodeType { - Basic, - Decision, + Basic = 'basic', + Decision = 'decision', + Nope = 'nope', } export type NodeData = { id: string; name: string } & ( | { nodeType: WorkflowNodeType.Basic; promptName: string; - type: 'text' | 'image'; + type: WorkflowExecutorType; // update the prompt params by output with the custom key paramKey?: string; } - | { nodeType: WorkflowNodeType.Decision; condition: string } + | { + nodeType: WorkflowNodeType.Decision; + condition: + | ((nodeIds: string[], params: WorkflowNodeState) => string) + | string; + } + // do nothing node + | { nodeType: WorkflowNodeType.Nope } ); export type WorkflowNodeState = Record; export type WorkflowGraphData = Array; -export type WorkflowGraphList = Array<{ +export type WorkflowGraph = { name: string; + // true if the graph has been modified + modified?: boolean; graph: WorkflowGraphData; -}>; +}; +export type WorkflowGraphs = Array; export enum WorkflowResultType { StartRun, @@ -33,7 +45,7 @@ export enum WorkflowResultType { export type WorkflowResult = | { type: WorkflowResultType.StartRun; nodeId: string } - | { type: WorkflowResultType.EndRun; nextNode: WorkflowNode } + | { type: WorkflowResultType.EndRun; nextNode?: WorkflowNode } | { type: WorkflowResultType.Params; params: Record; @@ -42,8 +54,6 @@ export type WorkflowResult = type: WorkflowResultType.Content; nodeId: string; content: string; - // if is the end of the workflow, pass through the content to stream response - passthrough?: boolean; }; -export type WorkflowGraph = Map; +export type WorkflowGraphInstances = Map; diff --git a/packages/backend/server/src/plugins/copilot/workflow/worker.mjs b/packages/backend/server/src/plugins/copilot/workflow/worker.mjs new file mode 100644 index 0000000000000..339e5c3d107bf --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/workflow/worker.mjs @@ -0,0 +1,11 @@ +import 'ses'; + +lockdown(); + +const sandbox = new Compartment(); + +export default ({ iife, nodeIds, params }) => { + sandbox.globalThis.nodeIds = harden(nodeIds); + sandbox.globalThis.params = harden(params); + return sandbox.evaluate(iife); +}; diff --git a/packages/backend/server/src/plugins/copilot/workflow/workflow.ts b/packages/backend/server/src/plugins/copilot/workflow/workflow.ts index b6011cd328035..2a453357d8470 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/workflow.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/workflow.ts @@ -1,12 +1,10 @@ import { Logger } from '@nestjs/common'; -import { PromptService } from '../prompt'; -import { CopilotProviderService } from '../providers'; import { CopilotChatOptions } from '../types'; import { WorkflowNode } from './node'; import { - WorkflowGraph, - WorkflowNodeState, + type WorkflowGraphInstances, + type WorkflowNodeState, WorkflowNodeType, WorkflowResultType, } from './types'; @@ -15,11 +13,7 @@ export class CopilotWorkflow { private readonly logger = new Logger(CopilotWorkflow.name); private readonly rootNode: WorkflowNode; - constructor( - private readonly prompt: PromptService, - private readonly provider: CopilotProviderService, - workflow: WorkflowGraph - ) { + constructor(workflow: WorkflowGraphInstances) { const startNode = workflow.get('start'); if (!startNode) { throw new Error(`No start node found in graph`); @@ -38,8 +32,6 @@ export class CopilotWorkflow { let result = ''; let nextNode: WorkflowNode | undefined; - await currentNode.initNode(this.prompt, this.provider); - for await (const ret of currentNode.next(lastParams, options)) { if (ret.type === WorkflowResultType.EndRun) { nextNode = ret.nextNode; @@ -53,8 +45,8 @@ export class CopilotWorkflow { ); } } else if (ret.type === WorkflowResultType.Content) { - if (ret.passthrough) { - // pass through content as a stream response + if (!currentNode.hasEdges) { + // pass through content as a stream response if node is end node yield ret.content; } else { result += ret.content; @@ -70,7 +62,9 @@ export class CopilotWorkflow { } currentNode = nextNode; - if (result) lastParams.content = result; + if (result && lastParams.content !== result) { + lastParams.content = result; + } } } } diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts index 500768800665a..ef29369ed27de 100644 --- a/packages/backend/server/tests/copilot.e2e.ts +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -259,7 +259,7 @@ test('should be able to chat with api by workflow', async t => { const ret = await chatWithWorkflow(app, token, sessionId, messageId); t.is( ret, - textToEventStream('generate text to text stream', messageId), + textToEventStream(['generate text to text stream'], messageId), 'should be able to chat with workflow' ); }); diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 7fd9f40931ad8..71b0fd4efe820 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -3,6 +3,7 @@ import { TestingModule } from '@nestjs/testing'; import type { TestFn } from 'ava'; import ava from 'ava'; +import Sinon from 'sinon'; import { AuthService } from '../src/core/auth'; import { QuotaModule } from '../src/core/quota'; @@ -21,7 +22,20 @@ import { CopilotCapability, CopilotProviderType, } from '../src/plugins/copilot/types'; -import { CopilotWorkflowService } from '../src/plugins/copilot/workflow'; +import { + CopilotChatTextExecutor, + CopilotWorkflowService, + WorkflowNodeType, +} from '../src/plugins/copilot/workflow'; +import { + getWorkflowExecutor, + WorkflowExecutorType, +} from '../src/plugins/copilot/workflow/executor'; +import { WorkflowGraphList } from '../src/plugins/copilot/workflow/graph'; +import { + NodeData, + WorkflowResultType, +} from '../src/plugins/copilot/workflow/types'; import { createTestingModule } from './utils'; import { MockCopilotTestProvider } from './utils/copilot'; @@ -32,6 +46,7 @@ const test = ava as TestFn<{ provider: CopilotProviderService; session: ChatSessionService; workflow: CopilotWorkflowService; + textWorkflowExecutor: CopilotChatTextExecutor; }>; test.beforeEach(async t => { @@ -59,6 +74,7 @@ test.beforeEach(async t => { const provider = module.get(CopilotProviderService); const session = module.get(ChatSessionService); const workflow = module.get(CopilotWorkflowService); + const textWorkflowExecutor = module.get(CopilotChatTextExecutor); t.context.module = module; t.context.auth = auth; @@ -66,6 +82,7 @@ test.beforeEach(async t => { t.context.provider = provider; t.context.session = session; t.context.workflow = workflow; + t.context.textWorkflowExecutor = textWorkflowExecutor; }); test.afterEach.always(async t => { @@ -541,10 +558,14 @@ test('should be able to register test provider', async t => { await assertProvider(CopilotCapability.ImageToText); }); +// ==================== workflow ==================== + // this test used to preview the final result of the workflow // for the functional test of the API itself, refer to the follow tests test.skip('should be able to preview workflow', async t => { - const { prompt, workflow } = t.context; + const { prompt, workflow, textWorkflowExecutor } = t.context; + + textWorkflowExecutor.register(); registerCopilotProvider(OpenAIProvider); for (const p of prompts) { @@ -554,13 +575,174 @@ test.skip('should be able to preview workflow', async t => { let result = ''; for await (const ret of workflow.runGraph( { content: 'apple company' }, - 'workflow:presentation' + 'presentation' )) { result += ret; console.log('stream result:', ret); } console.log('final stream result:', result); + t.truthy(result, 'should return result'); + + unregisterCopilotProvider(OpenAIProvider.type); +}); + +test('should be able to run workflow', async t => { + const { prompt, workflow, textWorkflowExecutor } = t.context; + + textWorkflowExecutor.register(); + unregisterCopilotProvider(OpenAIProvider.type); + registerCopilotProvider(MockCopilotTestProvider); + + const executor = Sinon.spy(textWorkflowExecutor, 'next'); + + for (const p of prompts) { + await prompt.set(p.name, p.model, p.messages); + } + + const graphName = 'presentation'; + const graph = WorkflowGraphList.find(g => g.name === graphName); + t.truthy(graph, `graph ${graphName} not defined`); + + // todo: use Array.fromAsync + let result = ''; + for await (const ret of workflow.runGraph( + { content: 'apple company' }, + graphName + )) { + result += ret; + } + t.assert(result, 'generate text to text stream'); + + // presentation workflow has condition node, it will always false + // so the latest 2 nodes will not be executed + const callCount = graph!.graph.length - 3; + t.is( + executor.callCount, + callCount, + `should call executor ${callCount} times` + ); + + for (const [idx, node] of graph!.graph + .filter(g => g.nodeType === WorkflowNodeType.Basic) + .entries()) { + const params = executor.getCall(idx); + + if (idx < callCount) { + t.is(params.args[0].id, node.id, 'graph id should correct'); + + t.is( + params.args[1].content, + 'generate text to text stream', + 'graph params should correct' + ); + t.is( + params.args[1].language, + 'generate text to text', + 'graph params should correct' + ); + } + } + + unregisterCopilotProvider(MockCopilotTestProvider.type); + registerCopilotProvider(OpenAIProvider); +}); +// ==================== workflow executor ==================== + +const wrapAsyncIter = async (iter: AsyncIterable) => { + const result: T[] = []; + for await (const r of iter) { + result.push(r); + } + return result; +}; + +test('should be able to run executor', async t => { + const { textWorkflowExecutor } = t.context; + + textWorkflowExecutor.register(); + const executor = getWorkflowExecutor(textWorkflowExecutor.type); + t.is(executor.type, textWorkflowExecutor.type, 'should get executor'); + + await t.throwsAsync( + wrapAsyncIter( + executor.next( + { id: 'nope', name: 'nope', nodeType: WorkflowNodeType.Nope }, + {} + ) + ), + { instanceOf: Error }, + 'should throw error if run non basic node' + ); +}); + +test('should be able to run text executor', async t => { + const { textWorkflowExecutor, provider, prompt } = t.context; + + textWorkflowExecutor.register(); + const executor = getWorkflowExecutor(textWorkflowExecutor.type); unregisterCopilotProvider(OpenAIProvider.type); - t.pass(); + registerCopilotProvider(MockCopilotTestProvider); + await prompt.set('test', 'test', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + // mock provider + const testProvider = + (await provider.getProviderByModel('test'))!; + const text = Sinon.spy(testProvider, 'generateText'); + const textStream = Sinon.spy(testProvider, 'generateTextStream'); + + const nodeData: NodeData = { + id: 'basic', + name: 'basic', + nodeType: WorkflowNodeType.Basic, + promptName: 'test', + type: WorkflowExecutorType.ChatText, + }; + + // text + { + const ret = await wrapAsyncIter( + executor.next({ ...nodeData, paramKey: 'key' }, { word: 'world' }) + ); + + t.deepEqual(ret, [ + { + type: WorkflowResultType.Params, + params: { key: 'generate text to text' }, + }, + ]); + t.deepEqual( + text.lastCall.args[0][0].content, + 'hello world', + 'should render the prompt with params' + ); + } + + // text stream with attachment + { + const ret = await wrapAsyncIter( + executor.next(nodeData, { + attachments: ['https://affine.pro/example.jpg'], + }) + ); + + t.deepEqual( + ret, + Array.from('generate text to text stream').map(t => ({ + content: t, + nodeId: 'basic', + type: WorkflowResultType.Content, + })) + ); + t.deepEqual( + textStream.lastCall.args[0][0].params?.attachments, + ['https://affine.pro/example.jpg'], + 'should pass attachments to provider' + ); + } + + Sinon.restore(); + unregisterCopilotProvider(MockCopilotTestProvider.type); + registerCopilotProvider(OpenAIProvider); }); diff --git a/yarn.lock b/yarn.lock index d885fa87347e6..cce0b1b37786f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -806,12 +806,14 @@ __metadata: on-headers: "npm:^1.0.2" openai: "npm:^4.33.0" parse-duration: "npm:^1.1.0" + piscina: "npm:^4.5.1" pretty-time: "npm:^1.1.0" prisma: "npm:^5.12.1" prom-client: "npm:^15.1.1" reflect-metadata: "npm:^0.2.2" rxjs: "npm:^7.8.1" semver: "npm:^7.6.0" + ses: "npm:^1.4.1" sinon: "npm:^18.0.0" socket.io: "npm:^4.7.5" stripe: "npm:^15.0.0" @@ -30889,6 +30891,17 @@ __metadata: languageName: node linkType: hard +"nice-napi@npm:^1.0.2": + version: 1.0.2 + resolution: "nice-napi@npm:1.0.2" + dependencies: + node-addon-api: "npm:^3.0.0" + node-gyp: "npm:latest" + node-gyp-build: "npm:^4.2.2" + conditions: "!os=win32" + languageName: node + linkType: hard + "nice-try@npm:^1.0.4": version: 1.0.5 resolution: "nice-try@npm:1.0.5" @@ -30935,6 +30948,15 @@ __metadata: languageName: node linkType: hard +"node-addon-api@npm:^3.0.0": + version: 3.2.1 + resolution: "node-addon-api@npm:3.2.1" + dependencies: + node-gyp: "npm:latest" + checksum: 10/681b52dfa3e15b0a8e5cf283cc0d8cd5fd2a57c559ae670fcfd20544cbb32f75de7648674110defcd17ab2c76ebef630aa7d2d2f930bc7a8cc439b20fe233518 + languageName: node + linkType: hard + "node-api-version@npm:^0.2.0": version: 0.2.0 resolution: "node-api-version@npm:0.2.0" @@ -32505,6 +32527,18 @@ __metadata: languageName: node linkType: hard +"piscina@npm:^4.5.1": + version: 4.5.1 + resolution: "piscina@npm:4.5.1" + dependencies: + nice-napi: "npm:^1.0.2" + dependenciesMeta: + nice-napi: + optional: true + checksum: 10/a450140a3266920417844c11164bba208d4df2ddcb683bf3ee99d8d4b7db9de9751f5b83f6f5b2146391d18012d9f4af30a69ed751b9d4aa5eea1771a37df275 + languageName: node + linkType: hard + "pkg-dir@npm:^3.0.0": version: 3.0.0 resolution: "pkg-dir@npm:3.0.0"