From 4ebad972fc7e9bef96ba20b707ab836aa2f2b73f Mon Sep 17 00:00:00 2001 From: dosco <832235+dosco@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:50:53 -0700 Subject: [PATCH] fix: gemini batch embed endpoint --- src/ax/ai/google-gemini/api.ts | 11 +++++++---- src/ax/ai/google-gemini/types.ts | 6 ++++-- src/ax/dsp/sig.ts | 4 ++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/ax/ai/google-gemini/api.ts b/src/ax/ai/google-gemini/api.ts index 0c88a98..6c3d6c0 100644 --- a/src/ax/ai/google-gemini/api.ts +++ b/src/ax/ai/google-gemini/api.ts @@ -56,7 +56,7 @@ const safetySettings: AxAIGoogleGeminiSafetySettings = [ export const axAIGoogleGeminiDefaultConfig = (): AxAIGoogleGeminiConfig => structuredClone({ model: AxAIGoogleGeminiModel.Gemini15Pro, - embedModel: AxAIGoogleGeminiEmbedModel.Embedding001, + embedModel: AxAIGoogleGeminiEmbedModel.TextEmbedding004, safetySettings, ...axBaseAIDefaultConfig() }); @@ -65,7 +65,7 @@ export const axAIGoogleGeminiDefaultCreativeConfig = (): AxAIGoogleGeminiConfig => structuredClone({ model: AxAIGoogleGeminiModel.Gemini15Flash, - embedModel: AxAIGoogleGeminiEmbedModel.Embedding001, + embedModel: AxAIGoogleGeminiEmbedModel.TextEmbedding004, safetySettings, ...axBaseAIDefaultCreativeConfig() }); @@ -345,11 +345,14 @@ export class AxAIGoogleGemini extends AxBaseAI< } const apiConfig = { - name: `/models/${model}:batchEmbedText?key=${this.apiKey}` + name: `/models/${model}:batchEmbedContents?key=${this.apiKey}` }; const reqValue: AxAIGoogleGeminiBatchEmbedRequest = { - requests: req.texts.map((text) => ({ model, text })) + requests: req.texts.map((text) => ({ + model: 'models/' + model, + content: { parts: [{ text }] } + })) }; return [apiConfig, reqValue]; diff --git a/src/ax/ai/google-gemini/types.ts b/src/ax/ai/google-gemini/types.ts index e7bde80..0832870 100644 --- a/src/ax/ai/google-gemini/types.ts +++ b/src/ax/ai/google-gemini/types.ts @@ -10,7 +10,7 @@ export enum AxAIGoogleGeminiModel { } export enum AxAIGoogleGeminiEmbedModel { - Embedding001 = 'embedding-001' + TextEmbedding004 = 'text-embedding-004' } export enum AxAIGoogleGeminiSafetyCategory { @@ -159,7 +159,9 @@ export type AxAIGoogleGeminiConfig = AxModelConfig & { export type AxAIGoogleGeminiBatchEmbedRequest = { requests: { model: string; - text: string; + content: { + parts: { text: string }[]; + }; }[]; }; diff --git a/src/ax/dsp/sig.ts b/src/ax/dsp/sig.ts index e65edc2..e71e8ea 100644 --- a/src/ax/dsp/sig.ts +++ b/src/ax/dsp/sig.ts @@ -3,10 +3,10 @@ import { createHash } from 'crypto'; import type { AxFunctionJSONSchema } from '../ai/types.js'; import { - parseSignature, type InputParsedField, type OutputParsedField, - type ParsedSignature + type ParsedSignature, + parseSignature } from './parser.js'; export interface AxField {