From 95a0680ec61d8d9803c6cea0b2732abebf046461 Mon Sep 17 00:00:00 2001 From: dosco <832235+dosco@users.noreply.github.com> Date: Thu, 20 Jun 2024 00:40:10 -0700 Subject: [PATCH] feat: added multi-modal support to anthropic api and other fixes --- .cspell/project-words.txt | 1 + package.json | 2 +- src/ai/anthropic/api.ts | 229 ++++++++++++++++++++++++++---------- src/ai/anthropic/types.ts | 128 ++++++++++---------- src/ai/base.ts | 3 + src/ai/cohere/api.ts | 2 - src/ai/google-gemini/api.ts | 66 ++++++----- src/ai/ollama/api.ts | 27 +++-- src/ai/openai/api.ts | 49 ++++---- src/ai/util.ts | 10 +- src/dsp/generate.ts | 11 +- src/examples/streaming1.ts | 9 +- src/examples/streaming2.ts | 13 +- src/examples/summarize.ts | 2 +- src/mem/memory.ts | 24 +++- src/prompts/prompts.test.ts | 2 +- 16 files changed, 367 insertions(+), 211 deletions(-) diff --git a/.cspell/project-words.txt b/.cspell/project-words.txt index 173ecd2b..f4325545 100644 --- a/.cspell/project-words.txt +++ b/.cspell/project-words.txt @@ -31,6 +31,7 @@ Logprob logprobs Logprobs Macbook +minilm Mixtral nanos neumann diff --git a/package.json b/package.json index 3efe71dd..7599fedb 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@ax-llm/ax", - "version": "9.0.2", + "version": "9.0.3", "type": "module", "description": "The best library to work with LLMs", "typings": "build/module/src/index.d.ts", diff --git a/src/ai/anthropic/api.ts b/src/ai/anthropic/api.ts index 70d81225..241e9489 100644 --- a/src/ai/anthropic/api.ts +++ b/src/ai/anthropic/api.ts @@ -4,7 +4,6 @@ import type { AxAIServiceOptions, AxChatRequest, AxChatResponse, - AxChatResponseResult, AxModelConfig } from '../types.js'; @@ -17,6 +16,7 @@ import { type AxAnthropicConfig, type AxAnthropicContentBlockDeltaEvent, type AxAnthropicContentBlockStartEvent, + type AxAnthropicErrorEvent, type AxAnthropicMessageDeltaEvent, type AxAnthropicMessageStartEvent, AxAnthropicModel @@ -56,7 +56,6 @@ export class AxAnthropic extends AxBaseAI< apiURL: 'https://api.anthropic.com/v1', headers: { 'anthropic-version': '2023-06-01', - 'anthropic-beta': 'tools-2024-04-04', 'x-api-key': apiKey }, modelInfo: axModelInfoAnthropic, @@ -86,20 +85,7 @@ export class AxAnthropic extends AxBaseAI< name: '/messages' }; - const messages = - req.chatPrompt?.map((msg) => { - if (msg.role === 'function') { - return { - role: 'user' as 'user' | 'assistant' | 'system', - content: msg.content, - tool_use_id: msg.functionId - }; - } - return { - role: msg.role as 'user' | 'assistant' | 'system', - content: msg.content ?? '' - }; - }) ?? []; + const messages = createMessages(req); const tools: AxAnthropicChatRequest['tools'] = req.functions?.map((v) => ({ name: v.name, @@ -107,6 +93,8 @@ export class AxAnthropic extends AxBaseAI< input_schema: v.parameters })); + const stream = req.modelConfig?.stream ?? this.config.stream; + const reqValue: AxAnthropicChatRequest = { model: req.modelInfo?.name ?? this.config.model, max_tokens: req.modelConfig?.maxTokens ?? this.config.maxTokens, @@ -116,6 +104,7 @@ export class AxAnthropic extends AxBaseAI< top_p: req.modelConfig?.topP ?? this.config.topP, top_k: req.modelConfig?.topK ?? this.config.topK, ...(tools && tools.length > 0 ? { tools } : {}), + ...(stream ? { stream: true } : {}), messages }; @@ -123,14 +112,12 @@ export class AxAnthropic extends AxBaseAI< }; override generateChatResp = ( - response: Readonly + resp: Readonly ): AxChatResponse => { - const err = response as AxAnthropicChatError; - if (err.type === 'error') { - throw new Error(`Anthropic Chat API Error: ${err.error.message}`); + if (resp.type === 'error') { + throw new Error(`Anthropic Chat API Error: ${resp.error.message}`); } - const resp = response as AxAnthropicChatResponse; const results = resp.content.map((msg) => { let finishReason: AxChatResponse['results'][0]['finishReason']; @@ -169,65 +156,187 @@ export class AxAnthropic extends AxBaseAI< }; override generateChatStreamResp = ( - resp: Readonly + resp: Readonly, + state: object ): AxChatResponse => { - let results: AxChatResponseResult[] = []; - let modelUsage; + if (!('type' in resp)) { + throw new Error('Invalid Anthropic streaming event'); + } + + const sstate = state as { + indexIdMap: Record; + }; + + if (!sstate.indexIdMap) { + sstate.indexIdMap = {}; + } + + if (resp.type === 'error') { + const { error } = resp as unknown as AxAnthropicErrorEvent; + throw new Error(error.message); + } - if ('message' in resp) { + if (resp.type === 'message_start') { const { message } = resp as unknown as AxAnthropicMessageStartEvent; - results = [ - { - content: '', - id: message.id - } - ]; - modelUsage = { - promptTokens: resp.usage?.input_tokens ?? 0, - completionTokens: resp.usage?.output_tokens ?? 0, + const results = [{ content: '', id: message.id }]; + const modelUsage = { + promptTokens: message.usage?.input_tokens ?? 0, + completionTokens: message.usage?.output_tokens ?? 0, totalTokens: - (resp.usage?.input_tokens ?? 0) + (resp.usage?.output_tokens ?? 0) + (message.usage?.input_tokens ?? 0) + + (message.usage?.output_tokens ?? 0) + }; + return { + results, + modelUsage }; } - if ('content_block' in resp) { - const { content_block: cb } = + if (resp.type === 'content_block_start') { + const { content_block: contentBlock } = resp as unknown as AxAnthropicContentBlockStartEvent; - results = [{ content: cb.text }]; + + if (contentBlock.type === 'text') { + return { + results: [{ content: contentBlock.text }] + }; + } + if (contentBlock.type === 'tool_use') { + if ( + typeof contentBlock.id === 'string' && + typeof resp.index === 'number' && + !sstate.indexIdMap[resp.index] + ) { + sstate.indexIdMap[resp.index] = contentBlock.id; + } + } } - if ( - 'delta' in resp && - 'text' in (resp as unknown as AxAnthropicContentBlockDeltaEvent).delta - ) { - const { delta: cb } = - resp as unknown as AxAnthropicContentBlockDeltaEvent; - results = [{ content: cb.text }]; + if (resp.type === 'content_block_delta') { + const { delta } = resp as unknown as AxAnthropicContentBlockDeltaEvent; + if (delta.type === 'text_delta') { + return { + results: [{ content: delta.text }] + }; + } + if (delta.type === 'input_json_delta') { + const id = sstate.indexIdMap[resp.index]; + if (!id) { + throw new Error('invalid streaming index no id found: ' + resp.index); + } + const functionCalls = [ + { + id, + type: 'function' as const, + function: { + name: '', + arguments: delta.partial_json + } + } + ]; + return { + results: [{ functionCalls }] + }; + } } - if ( - 'delta' in resp && - 'stop_reason' in (resp as unknown as AxAnthropicMessageDeltaEvent).delta - ) { - const { delta } = resp as unknown as AxAnthropicMessageDeltaEvent; - results = [ - { content: '', finishReason: mapFinishReason(delta.stop_reason) } - ]; - modelUsage = { - promptTokens: resp.usage?.input_tokens ?? 0, - completionTokens: resp.usage?.output_tokens ?? 0, - totalTokens: - (resp.usage?.input_tokens ?? 0) + (resp.usage?.output_tokens ?? 0) + if (resp.type === 'message_delta') { + const { delta, usage } = resp as unknown as AxAnthropicMessageDeltaEvent; + return { + results: [ + { + content: '', + finishReason: mapFinishReason(delta.stop_reason) + } + ], + modelUsage: { + promptTokens: 0, + completionTokens: usage.output_tokens, + totalTokens: usage.output_tokens + } }; } return { - results, - modelUsage + results: [{ content: '' }] }; }; } +function createMessages( + req: Readonly +): AxAnthropicChatRequest['messages'] { + return req.chatPrompt.map((msg) => { + switch (msg.role) { + case 'function': + return { + role: 'user' as const, + content: [ + { + type: 'tool_result', + text: msg.content, + tool_use_id: msg.functionId + } + ] + }; + case 'user': { + if (typeof msg.content === 'string') { + return { role: 'user' as const, content: msg.content }; + } + const content = msg.content.map((v) => { + switch (v.type) { + case 'text': + return { type: 'text' as const, text: v.text }; + case 'image': + return { + type: 'image' as const, + source: { + type: 'base64' as const, + media_type: v.mimeType, + data: v.image + } + }; + default: + throw new Error('Invalid content type'); + } + }); + return { + role: 'user' as const, + content + }; + } + case 'assistant': { + if (typeof msg.content === 'string') { + return { role: 'assistant' as const, content: msg.content }; + } + if (typeof msg.functionCalls !== 'undefined') { + const content = msg.functionCalls.map((v) => { + let input; + if (typeof v.function.arguments === 'string') { + input = JSON.parse(v.function.arguments); + } else if (typeof v.function.arguments === 'object') { + input = v.function.arguments; + } + return { + type: 'tool_use' as const, + id: v.id, + name: v.function.name, + input + }; + }); + return { + role: 'assistant' as const, + content + }; + } + throw new Error('Invalid content type'); + } + default: + throw new Error('Invalid role'); + } + }); +} + function mapFinishReason( stopReason?: AxAnthropicChatResponse['stop_reason'] | null ): AxChatResponse['results'][0]['finishReason'] | undefined { diff --git a/src/ai/anthropic/types.ts b/src/ai/anthropic/types.ts index e251e1b2..c485e149 100644 --- a/src/ai/anthropic/types.ts +++ b/src/ai/anthropic/types.ts @@ -15,22 +15,30 @@ export type AxAnthropicConfig = AxModelConfig & { // Type for the request to create a message using Anthropic's Messages API export type AxAnthropicChatRequest = { model: string; - messages: { - role: 'user' | 'assistant' | 'system'; - content: - | string - | { - type: 'text' | 'image' | 'tool_result'; - text?: string; // Text content (if type is 'text') - tool_use_id?: string; - content?: string; - source?: { - type: 'base64'; - media_type: string; - data: string; - }; - }[]; - }[]; + messages: ( + | { + role: 'user'; + content: + | string + | ( + | { type: 'text'; text: string } + | { + type: 'image'; + source: { type: 'base64'; media_type: string; data: string }; + } + | { type: 'tool_result'; text: string; tool_use_id: string } + )[]; + } + | { + role: 'assistant'; + content: + | string + | ( + | { type: 'text'; text: string } + | { type: 'tool_use'; id: string; name: string; input: object } + )[]; + } + )[]; tools?: { name: string; description: string; @@ -76,13 +84,9 @@ export type AxAnthropicChatError = { }; }; -// Base interface for all event types in the stream -export interface AxAnthropicStreamEvent { - type: string; -} - // Represents the start of a message with an empty content array -export interface AxAnthropicMessageStartEvent extends AxAnthropicStreamEvent { +export interface AxAnthropicMessageStartEvent { + type: 'message_start'; message: { id: string; type: 'message'; @@ -99,58 +103,76 @@ export interface AxAnthropicMessageStartEvent extends AxAnthropicStreamEvent { } // Indicates the start of a content block within a message -export interface AxAnthropicContentBlockStartEvent - extends AxAnthropicStreamEvent { +export interface AxAnthropicContentBlockStartEvent { index: number; - content_block: { - type: 'text'; - text: string; - }; + type: 'content_block_start'; + content_block: + | { + type: 'text'; + text: string; + } + | { + type: 'tool_use'; + id: string; + name: string; + input: object; + }; } // Represents incremental updates to a content block -export interface AxAnthropicContentBlockDeltaEvent - extends AxAnthropicStreamEvent { +export interface AxAnthropicContentBlockDeltaEvent { index: number; - delta: { - type: 'text_delta'; - text: string; - }; + type: 'content_block_delta'; + delta: + | { + type: 'text_delta'; + text: string; + } + | { + type: 'input_json_delta'; + partial_json: string; + }; } // Marks the end of a content block within a message -export interface AxAnthropicContentBlockStopEvent - extends AxAnthropicStreamEvent { +export interface AxAnthropicContentBlockStopEvent { + type: 'content_block_stop'; index: number; } // Indicates top-level changes to the final message object -export interface AxAnthropicMessageDeltaEvent extends AxAnthropicStreamEvent { +export interface AxAnthropicMessageDeltaEvent { + type: 'message_delta'; delta: { stop_reason: 'end_turn' | 'max_tokens' | 'stop_sequence' | null; stop_sequence: string | null; - usage: { - output_tokens: number; - }; + }; + usage: { + output_tokens: number; }; } // Marks the end of a message -export type AxAnthropicMessageStopEvent = AxAnthropicStreamEvent; +export interface AxAnthropicMessageStopEvent { + type: 'message_stop'; +} // Represents a ping event, which can occur any number of times -export type AxAnthropicPingEvent = AxAnthropicStreamEvent; +export interface AxAnthropicPingEvent { + type: 'ping'; +} // Represents an error event -export interface AxAnthropicErrorEvent extends AxAnthropicStreamEvent { +export interface AxAnthropicErrorEvent { + type: 'error'; error: { - type: 'overloaded_error' | string; + type: 'overloaded_error'; message: string; }; } // Union type for all possible event types in the stream -export type AxAxAnthropicStreamEventType = +export type AxAnthropicChatResponseDelta = | AxAnthropicMessageStartEvent | AxAnthropicContentBlockStartEvent | AxAnthropicContentBlockDeltaEvent @@ -159,19 +181,3 @@ export type AxAxAnthropicStreamEventType = | AxAnthropicMessageStopEvent | AxAnthropicPingEvent | AxAnthropicErrorEvent; - -// Type for the response delta in streaming mode, using generic to allow flexibility -export interface AxAnthropicResponseDelta { - id: string; - object: 'message'; - model: string; - events: T[]; // Array of all event types that can occur in the stream - usage?: { - input_tokens: number; - output_tokens: number; - }; -} - -// Specific type for handling text deltas in the streaming response -export type AxAnthropicChatResponseDelta = - AxAnthropicResponseDelta; diff --git a/src/ai/base.ts b/src/ai/base.ts index f301a645..25bb679b 100644 --- a/src/ai/base.ts +++ b/src/ai/base.ts @@ -426,6 +426,9 @@ const logResponse = (resp: Readonly) => { }; const logStreamingResponse = (resp: Readonly) => { + if (!resp.results) { + return; + } for (const r of resp.results) { if (r.content) { process.stdout.write(colorLog.greenBright(r.content)); diff --git a/src/ai/cohere/api.ts b/src/ai/cohere/api.ts index 2eaab269..b3b11125 100644 --- a/src/ai/cohere/api.ts +++ b/src/ai/cohere/api.ts @@ -95,8 +95,6 @@ export class AxCohere extends AxBaseAI< _config: Readonly ): [API, AxCohereChatRequest] => { const model = req.modelInfo?.name ?? this.config.model; - // const functionsList = req.functions - // ? `Functions:\n${JSON.stringify(req.functions, null, 2)}\n` const lastChatMsg = req.chatPrompt.at(-1); const restOfChat = req.chatPrompt.slice(0, -1); diff --git a/src/ai/google-gemini/api.ts b/src/ai/google-gemini/api.ts index 19ab1fbb..2766ca0b 100644 --- a/src/ai/google-gemini/api.ts +++ b/src/ai/google-gemini/api.ts @@ -345,42 +345,44 @@ export class AxGoogleGemini extends AxBaseAI< override generateChatResp = ( resp: Readonly ): AxChatResponse => { - const results: AxChatResponseResult[] = resp.candidates.map((candidate) => { - const result: AxChatResponseResult = {}; - - switch (candidate.finishReason) { - case 'MAX_TOKENS': - result.finishReason = 'length'; - break; - case 'STOP': - result.finishReason = 'stop'; - break; - case 'SAFETY': - throw new Error('Finish reason: SAFETY'); - case 'RECITATION': - throw new Error('Finish reason: RECITATION'); - } - - for (const part of candidate.content.parts) { - if ('text' in part) { - result.content = part.text; - continue; + const results: AxChatResponseResult[] = resp.candidates?.map( + (candidate) => { + const result: AxChatResponseResult = {}; + + switch (candidate.finishReason) { + case 'MAX_TOKENS': + result.finishReason = 'length'; + break; + case 'STOP': + result.finishReason = 'stop'; + break; + case 'SAFETY': + throw new Error('Finish reason: SAFETY'); + case 'RECITATION': + throw new Error('Finish reason: RECITATION'); } - if ('functionCall' in part) { - result.functionCalls = [ - { - id: part.functionCall.name, - type: 'function', - function: { - name: part.functionCall.name, - arguments: part.functionCall.args + + for (const part of candidate.content.parts) { + if ('text' in part) { + result.content = part.text; + continue; + } + if ('functionCall' in part) { + result.functionCalls = [ + { + id: part.functionCall.name, + type: 'function', + function: { + name: part.functionCall.name, + arguments: part.functionCall.args + } } - } - ]; + ]; + } } + return result; } - return result; - }); + ); let modelUsage: AxTokenUsage | undefined; if (resp.usageMetadata) { diff --git a/src/ai/ollama/api.ts b/src/ai/ollama/api.ts index 5cfdd5ab..8be1b9be 100644 --- a/src/ai/ollama/api.ts +++ b/src/ai/ollama/api.ts @@ -8,24 +8,26 @@ import type { AxAIServiceOptions } from '../types.js'; export type AxOllamaAIConfig = AxOpenAIConfig; -export const axOllamaDefaultConfig = (): Omit => +export const axOllamaDefaultConfig = (): AxOllamaAIConfig => structuredClone({ - ...axBaseAIDefaultConfig() + ...axBaseAIDefaultConfig(), + model: 'nous-hermes2', + embedModel: 'all-minilm' }); -export const axOllamaDefaultCreativeConfig = (): Omit< - AxOllamaAIConfig, - 'model' -> => +export const axOllamaDefaultCreativeConfig = (): AxOllamaAIConfig => structuredClone({ - ...axBaseAIDefaultCreativeConfig() + ...axBaseAIDefaultCreativeConfig(), + model: 'nous-hermes2', + embedModel: 'all-minilm' }); export type AxOllamaArgs = { - model: string; + model?: string; + embedModel?: string; url?: string; apiKey?: string; - config?: Readonly>; + config?: Readonly; options?: Readonly; }; @@ -38,13 +40,18 @@ export class AxOllama extends AxOpenAI { apiKey = 'not-set', url = 'http://localhost:11434', model, + embedModel, config = axOllamaDefaultConfig(), options }: Readonly) { super({ apiKey, options, - config: { ...config, model }, + config: { + ...config, + ...(model ? { model } : {}), + ...(embedModel ? { embedModel } : {}) + }, apiURL: new URL('/v1', url).href }); diff --git a/src/ai/openai/api.ts b/src/ai/openai/api.ts index 83a2a331..9b6ba823 100644 --- a/src/ai/openai/api.ts +++ b/src/ai/openai/api.ts @@ -263,33 +263,32 @@ export class AxOpenAI extends AxBaseAI< ({ delta: { content, role, tool_calls }, finish_reason }) => { const finishReason = mapFinishReason(finish_reason); - const functionCalls = tool_calls - ?.map((v) => { - if ( - typeof v.id === 'string' && - typeof v.index === 'number' && - !sstate.indexIdMap[v.index] - ) { - sstate.indexIdMap[v.index] = v.id; - } + const functionCalls = tool_calls?.map((v) => { + if ( + typeof v.id === 'string' && + typeof v.index === 'number' && + !sstate.indexIdMap[v.index] + ) { + sstate.indexIdMap[v.index] = v.id; + } - const id = sstate.indexIdMap[v.index]; - if (!id) { - return null; - } + const id = sstate.indexIdMap[v.index]; + if (!id) { + throw new Error('invalid streaming index no id found: ' + v.index); + } - return { - id, - type: 'function' as const, - function: { - name: v.function.name, - arguments: v.function.arguments - } - }; - }) - .filter(Boolean) as NonNullable< - AxChatResponseResult['functionCalls'] - >; + return { + id, + type: 'function' as const, + function: { + name: v.function.name, + arguments: v.function.arguments + } + }; + }); + // .filter(Boolean) as NonNullable< + // AxChatResponseResult['functionCalls'] + // >; return { content, diff --git a/src/ai/util.ts b/src/ai/util.ts index ecefd71f..702c4425 100644 --- a/src/ai/util.ts +++ b/src/ai/util.ts @@ -73,11 +73,17 @@ export function mergeFunctionCalls( const fc = functionCalls.find((fc) => fc.id === _fc.id); if (fc) { - if (typeof _fc.function.name == 'string') { + if ( + typeof _fc.function.name == 'string' && + _fc.function.name.length > 0 + ) { fc.function.name += _fc.function.name; } - if (typeof _fc.function.arguments == 'string') { + if ( + typeof _fc.function.arguments == 'string' && + _fc.function.arguments.length > 0 + ) { fc.function.arguments += _fc.function.arguments; } diff --git a/src/dsp/generate.ts b/src/dsp/generate.ts index 22fd661e..01ad23df 100644 --- a/src/dsp/generate.ts +++ b/src/dsp/generate.ts @@ -277,7 +277,11 @@ export class AxGenerate< if (result.content) { content += result.content; - mem.updateResult({ ...result, content, functionCalls }, sessionId); + + mem.updateResult( + { name: result.name, content, functionCalls }, + sessionId + ); assertStreamingAssertions( this.streamingAsserts, @@ -302,7 +306,10 @@ export class AxGenerate< } if (funcs) { - mem.updateResult({ ...result, content, functionCalls }, sessionId); + mem.updateResult( + { name: result.name, content, functionCalls }, + sessionId + ); await this.processFunctions(funcs, mem, sessionId, traceId); } } diff --git a/src/examples/streaming1.ts b/src/examples/streaming1.ts index 35e957f8..d05dd2a6 100644 --- a/src/examples/streaming1.ts +++ b/src/examples/streaming1.ts @@ -1,10 +1,13 @@ import { axAI, AxChainOfThought, type AxOpenAIArgs } from '../index.js'; -// const ai = AI('openai', { apiKey: process.env.OPENAI_APIKEY } as AxOpenAIArgs); -const ai = axAI('google-gemini', { - apiKey: process.env.GOOGLE_APIKEY +const ai = axAI('openai', { + apiKey: process.env.OPENAI_APIKEY } as AxOpenAIArgs); +// const ai = axAI('google-gemini', { +// apiKey: process.env.GOOGLE_APIKEY +// } as AxOpenAIArgs); + // setup the prompt program const gen = new AxChainOfThought( ai, diff --git a/src/examples/streaming2.ts b/src/examples/streaming2.ts index 35e54ff1..6e8bd592 100644 --- a/src/examples/streaming2.ts +++ b/src/examples/streaming2.ts @@ -1,10 +1,13 @@ import { axAI, AxChainOfThought, type AxOpenAIArgs } from '../index.js'; -// const ai = AI('openai', { apiKey: process.env.OPENAI_APIKEY } as AxOpenAIArgs); -const ai = axAI('google-gemini', { - apiKey: process.env.GOOGLE_APIKEY +const ai = axAI('openai', { + apiKey: process.env.OPENAI_APIKEY } as AxOpenAIArgs); +// const ai = axAI('anthropic', { +// apiKey: process.env.ANTHROPIC_APIKEY +// } as AxAnthropicArgs); + // setup the prompt program const gen = new AxChainOfThought( ai, @@ -18,11 +21,11 @@ gen.addStreamingAssert( const re = /^\d+\./; // split the value by lines, trim each line, - // filter out empty lines and check if all lines match the regex + // filter out very short lines and check if all lines match the regex return value .split('\n') .map((x) => x.trim()) - .filter((x) => x.length > 0) + .filter((x) => x.length > 4) .every((x) => re.test(x)); }, 'Lines must start with a number and a dot. Eg: 1. This is a line.' diff --git a/src/examples/summarize.ts b/src/examples/summarize.ts index adc1f41a..480d5b13 100644 --- a/src/examples/summarize.ts +++ b/src/examples/summarize.ts @@ -14,7 +14,7 @@ const ai = axAI('openai', { ai.setOptions({ debug: true }); -// const ai = AI('ollama', { model: 'nous-hermes2' }); +// const ai = axAI('ollama', { model: 'nous-hermes2' }); const gen = new AxChainOfThought( ai, diff --git a/src/mem/memory.ts b/src/mem/memory.ts index 0f39c904..01677f62 100644 --- a/src/mem/memory.ts +++ b/src/mem/memory.ts @@ -2,6 +2,8 @@ import type { AxChatRequest, AxChatResponseResult } from '../ai/types.js'; import type { AxAIMemory } from './types.js'; +type Writeable = { -readonly [P in keyof T]: T[P] }; + export class AxMemory implements AxAIMemory { private data: AxChatRequest['chatPrompt'] = []; private sdata = new Map(); @@ -22,6 +24,7 @@ export class AxMemory implements AxAIMemory { ): void { const d = this.get(sessionId); let n = 0; + if (Array.isArray(value)) { n = d.push(...structuredClone(value)); } else { @@ -46,14 +49,23 @@ export class AxMemory implements AxAIMemory { sessionId?: string ): void { const item = this.get(sessionId); - if ('content' in item && content) { - item.content = content; + + const lastItem = item.at(-1) as unknown as Writeable< + AxChatRequest['chatPrompt'][0] + >; + if (!lastItem || lastItem.role !== 'assistant') { + this.addResult({ content, name, functionCalls }, sessionId); + return; + } + + if ('content' in lastItem && content) { + lastItem.content = content; } - if ('name' in item && name) { - item.name = name; + if ('name' in lastItem && name) { + lastItem.name = name; } - if ('functionCalls' in item && functionCalls) { - item.functionCalls = functionCalls; + if ('functionCalls' in lastItem && functionCalls) { + lastItem.functionCalls = functionCalls; } } diff --git a/src/prompts/prompts.test.ts b/src/prompts/prompts.test.ts index d4753de4..a364c807 100644 --- a/src/prompts/prompts.test.ts +++ b/src/prompts/prompts.test.ts @@ -48,7 +48,7 @@ test('generate prompt', async (t) => { const options = { fetch: mockFetch }; const ai = axAI('openai', { apiKey: 'no-key', options } as AxOpenAIArgs); - // const ai = AI('ollama', { model: 'nous-hermes2' }); + // const ai = axAI('ollama', { model: 'nous-hermes2' }); const gen = new AxChainOfThought( ai,