Skip to content

Commit

Permalink
feat: added multi-modal support to anthropic api and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jun 20, 2024
1 parent 12dd181 commit 95a0680
Show file tree
Hide file tree
Showing 16 changed files with 367 additions and 211 deletions.
1 change: 1 addition & 0 deletions .cspell/project-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Logprob
logprobs
Logprobs
Macbook
minilm
Mixtral
nanos
neumann
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@ax-llm/ax",
"version": "9.0.2",
"version": "9.0.3",
"type": "module",
"description": "The best library to work with LLMs",
"typings": "build/module/src/index.d.ts",
Expand Down
229 changes: 169 additions & 60 deletions src/ai/anthropic/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import type {
AxAIServiceOptions,
AxChatRequest,
AxChatResponse,
AxChatResponseResult,
AxModelConfig
} from '../types.js';

Expand All @@ -17,6 +16,7 @@ import {
type AxAnthropicConfig,
type AxAnthropicContentBlockDeltaEvent,
type AxAnthropicContentBlockStartEvent,
type AxAnthropicErrorEvent,
type AxAnthropicMessageDeltaEvent,
type AxAnthropicMessageStartEvent,
AxAnthropicModel
Expand Down Expand Up @@ -56,7 +56,6 @@ export class AxAnthropic extends AxBaseAI<
apiURL: 'https://api.anthropic.com/v1',
headers: {
'anthropic-version': '2023-06-01',
'anthropic-beta': 'tools-2024-04-04',
'x-api-key': apiKey
},
modelInfo: axModelInfoAnthropic,
Expand Down Expand Up @@ -86,27 +85,16 @@ export class AxAnthropic extends AxBaseAI<
name: '/messages'
};

const messages =
req.chatPrompt?.map((msg) => {
if (msg.role === 'function') {
return {
role: 'user' as 'user' | 'assistant' | 'system',
content: msg.content,
tool_use_id: msg.functionId
};
}
return {
role: msg.role as 'user' | 'assistant' | 'system',
content: msg.content ?? ''
};
}) ?? [];
const messages = createMessages(req);

const tools: AxAnthropicChatRequest['tools'] = req.functions?.map((v) => ({
name: v.name,
description: v.description,
input_schema: v.parameters
}));

const stream = req.modelConfig?.stream ?? this.config.stream;

const reqValue: AxAnthropicChatRequest = {
model: req.modelInfo?.name ?? this.config.model,
max_tokens: req.modelConfig?.maxTokens ?? this.config.maxTokens,
Expand All @@ -116,21 +104,20 @@ export class AxAnthropic extends AxBaseAI<
top_p: req.modelConfig?.topP ?? this.config.topP,
top_k: req.modelConfig?.topK ?? this.config.topK,
...(tools && tools.length > 0 ? { tools } : {}),
...(stream ? { stream: true } : {}),
messages
};

return [apiConfig, reqValue];
};

override generateChatResp = (
response: Readonly<AxAnthropicChatResponse | AxAnthropicChatError>
resp: Readonly<AxAnthropicChatResponse | AxAnthropicChatError>
): AxChatResponse => {
const err = response as AxAnthropicChatError;
if (err.type === 'error') {
throw new Error(`Anthropic Chat API Error: ${err.error.message}`);
if (resp.type === 'error') {
throw new Error(`Anthropic Chat API Error: ${resp.error.message}`);
}

const resp = response as AxAnthropicChatResponse;
const results = resp.content.map((msg) => {
let finishReason: AxChatResponse['results'][0]['finishReason'];

Expand Down Expand Up @@ -169,65 +156,187 @@ export class AxAnthropic extends AxBaseAI<
};

override generateChatStreamResp = (
resp: Readonly<AxAnthropicChatResponseDelta>
resp: Readonly<AxAnthropicChatResponseDelta>,
state: object
): AxChatResponse => {
let results: AxChatResponseResult[] = [];
let modelUsage;
if (!('type' in resp)) {
throw new Error('Invalid Anthropic streaming event');
}

const sstate = state as {
indexIdMap: Record<number, string>;
};

if (!sstate.indexIdMap) {
sstate.indexIdMap = {};
}

if (resp.type === 'error') {
const { error } = resp as unknown as AxAnthropicErrorEvent;
throw new Error(error.message);
}

if ('message' in resp) {
if (resp.type === 'message_start') {
const { message } = resp as unknown as AxAnthropicMessageStartEvent;
results = [
{
content: '',
id: message.id
}
];
modelUsage = {
promptTokens: resp.usage?.input_tokens ?? 0,
completionTokens: resp.usage?.output_tokens ?? 0,
const results = [{ content: '', id: message.id }];
const modelUsage = {
promptTokens: message.usage?.input_tokens ?? 0,
completionTokens: message.usage?.output_tokens ?? 0,
totalTokens:
(resp.usage?.input_tokens ?? 0) + (resp.usage?.output_tokens ?? 0)
(message.usage?.input_tokens ?? 0) +
(message.usage?.output_tokens ?? 0)
};
return {
results,
modelUsage
};
}

if ('content_block' in resp) {
const { content_block: cb } =
if (resp.type === 'content_block_start') {
const { content_block: contentBlock } =
resp as unknown as AxAnthropicContentBlockStartEvent;
results = [{ content: cb.text }];

if (contentBlock.type === 'text') {
return {
results: [{ content: contentBlock.text }]
};
}
if (contentBlock.type === 'tool_use') {
if (
typeof contentBlock.id === 'string' &&
typeof resp.index === 'number' &&
!sstate.indexIdMap[resp.index]
) {
sstate.indexIdMap[resp.index] = contentBlock.id;
}
}
}

if (
'delta' in resp &&
'text' in (resp as unknown as AxAnthropicContentBlockDeltaEvent).delta
) {
const { delta: cb } =
resp as unknown as AxAnthropicContentBlockDeltaEvent;
results = [{ content: cb.text }];
if (resp.type === 'content_block_delta') {
const { delta } = resp as unknown as AxAnthropicContentBlockDeltaEvent;
if (delta.type === 'text_delta') {
return {
results: [{ content: delta.text }]
};
}
if (delta.type === 'input_json_delta') {
const id = sstate.indexIdMap[resp.index];
if (!id) {
throw new Error('invalid streaming index no id found: ' + resp.index);
}
const functionCalls = [
{
id,
type: 'function' as const,
function: {
name: '',
arguments: delta.partial_json
}
}
];
return {
results: [{ functionCalls }]
};
}
}

if (
'delta' in resp &&
'stop_reason' in (resp as unknown as AxAnthropicMessageDeltaEvent).delta
) {
const { delta } = resp as unknown as AxAnthropicMessageDeltaEvent;
results = [
{ content: '', finishReason: mapFinishReason(delta.stop_reason) }
];
modelUsage = {
promptTokens: resp.usage?.input_tokens ?? 0,
completionTokens: resp.usage?.output_tokens ?? 0,
totalTokens:
(resp.usage?.input_tokens ?? 0) + (resp.usage?.output_tokens ?? 0)
if (resp.type === 'message_delta') {
const { delta, usage } = resp as unknown as AxAnthropicMessageDeltaEvent;
return {
results: [
{
content: '',
finishReason: mapFinishReason(delta.stop_reason)
}
],
modelUsage: {
promptTokens: 0,
completionTokens: usage.output_tokens,
totalTokens: usage.output_tokens
}
};
}

return {
results,
modelUsage
results: [{ content: '' }]
};
};
}

function createMessages(
req: Readonly<AxChatRequest>
): AxAnthropicChatRequest['messages'] {
return req.chatPrompt.map((msg) => {
switch (msg.role) {
case 'function':
return {
role: 'user' as const,
content: [
{
type: 'tool_result',
text: msg.content,
tool_use_id: msg.functionId
}
]
};
case 'user': {
if (typeof msg.content === 'string') {
return { role: 'user' as const, content: msg.content };
}
const content = msg.content.map((v) => {
switch (v.type) {
case 'text':
return { type: 'text' as const, text: v.text };
case 'image':
return {
type: 'image' as const,
source: {
type: 'base64' as const,
media_type: v.mimeType,
data: v.image
}
};
default:
throw new Error('Invalid content type');
}
});
return {
role: 'user' as const,
content
};
}
case 'assistant': {
if (typeof msg.content === 'string') {
return { role: 'assistant' as const, content: msg.content };
}
if (typeof msg.functionCalls !== 'undefined') {
const content = msg.functionCalls.map((v) => {
let input;
if (typeof v.function.arguments === 'string') {
input = JSON.parse(v.function.arguments);
} else if (typeof v.function.arguments === 'object') {
input = v.function.arguments;
}
return {
type: 'tool_use' as const,
id: v.id,
name: v.function.name,
input
};
});
return {
role: 'assistant' as const,
content
};
}
throw new Error('Invalid content type');
}
default:
throw new Error('Invalid role');
}
});
}

function mapFinishReason(
stopReason?: AxAnthropicChatResponse['stop_reason'] | null
): AxChatResponse['results'][0]['finishReason'] | undefined {
Expand Down
Loading

0 comments on commit 95a0680

Please sign in to comment.