Skip to content

Commit

Permalink
feat(google-vertexai): Support Non-Google and Model Garden models in …
Browse files Browse the repository at this point in the history
…Vertex AI - Anthropic integration (#6999)

Co-authored-by: jacoblee93 <[email protected]>
Co-authored-by: bracesproul <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent 14fa210 commit a1530da
Show file tree
Hide file tree
Showing 20 changed files with 2,957 additions and 674 deletions.
6 changes: 4 additions & 2 deletions docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
"source": [
"# ChatVertexAI\n",
"\n",
"[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.\n",
"[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.",
"It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).",
"\n",
"\n",
"This will help you getting started with `ChatVertexAI` [chat models](/docs/concepts/chat_models). For detailed documentation of all `ChatVertexAI` features and configurations head to the [API reference](https://api.js.langchain.com/classes/langchain_google_vertexai.ChatVertexAI.html).\n",
"\n",
Expand Down Expand Up @@ -279,4 +281,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
13 changes: 12 additions & 1 deletion docs/core_docs/docs/integrations/platforms/google.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Functionality related to [Google Cloud Platform](https://cloud.google.com/)

### Gemini Models

Access Gemini models such as `gemini-pro` and `gemini-pro-vision` through the [`ChatGoogleGenerativeAI`](/docs/integrations/chat/google_generativeai),
Access Gemini models such as `gemini-1.5-pro` and `gemini-1.5-flex` through the [`ChatGoogleGenerativeAI`](/docs/integrations/chat/google_generativeai),
or if using VertexAI, via the [`ChatVertexAI`](/docs/integrations/chat/google_vertex_ai) class.

import Tabs from "@theme/Tabs";
Expand Down Expand Up @@ -153,6 +153,17 @@ Click [here](/docs/integrations/chat/google_vertex_ai) for the `@langchain/googl

The value of `image_url` must be a base64 encoded image (e.g., ``).

### Non-Gemini Models

See above for setting up authentication through Vertex AI to use these models.

[Anthropic](/docs/integrations/chat/anthropic) Claude models are also available through
the [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude)
platform. See [here](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude)
for more information about enabling access to the models and the model names to use.

PaLM models are no longer supported.

## Vector Store

### Vertex AI Vector Search
Expand Down
121 changes: 43 additions & 78 deletions libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ import {
GoogleAISafetySetting,
GoogleConnectionParams,
GooglePlatformType,
GeminiContent,
GeminiTool,
GoogleAIBaseLanguageModelCallOptions,
GoogleAIAPI,
GoogleAIAPIParams,
} from "./types.js";
import {
convertToGeminiTools,
copyAIModelParams,
copyAndValidateModelParamsInto,
} from "./utils/common.js";
import { AbstractGoogleLLMConnection } from "./connection.js";
import { DefaultGeminiSafetyHandler } from "./utils/gemini.js";
import { DefaultGeminiSafetyHandler, getGeminiAPI } from "./utils/gemini.js";
import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js";
import { JsonStream } from "./utils/stream.js";
import { ensureParams } from "./utils/failed_handler.js";
Expand Down Expand Up @@ -96,71 +97,21 @@ export class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
return true;
}

async formatContents(
input: BaseMessage[],
_parameters: GoogleAIModelParams
): Promise<GeminiContent[]> {
const inputPromises: Promise<GeminiContent[]>[] = input.map((msg, i) =>
this.api.baseMessageToContent(
msg,
input[i - 1],
this.useSystemInstruction
)
);
const inputs = await Promise.all(inputPromises);

return inputs.reduce((acc, cur) => {
// Filter out the system content
if (cur.every((content) => content.role === "system")) {
return acc;
}

// Combine adjacent function messages
if (
cur[0]?.role === "function" &&
acc.length > 0 &&
acc[acc.length - 1].role === "function"
) {
acc[acc.length - 1].parts = [
...acc[acc.length - 1].parts,
...cur[0].parts,
];
} else {
acc.push(...cur);
}

return acc;
}, [] as GeminiContent[]);
buildGeminiAPI(): GoogleAIAPI {
const geminiConfig: GeminiAPIConfig = {
useSystemInstruction: this.useSystemInstruction,
...(this.apiConfig as GeminiAPIConfig),
};
return getGeminiAPI(geminiConfig);
}

async formatSystemInstruction(
input: BaseMessage[],
_parameters: GoogleAIModelParams
): Promise<GeminiContent> {
if (!this.useSystemInstruction) {
return {} as GeminiContent;
get api(): GoogleAIAPI {
switch (this.apiName) {
case "google":
return this.buildGeminiAPI();
default:
return super.api;
}

let ret = {} as GeminiContent;
for (let index = 0; index < input.length; index += 1) {
const message = input[index];
if (message._getType() === "system") {
// For system types, we only want it if it is the first message,
// if it appears anywhere else, it should be an error.
if (index === 0) {
// eslint-disable-next-line prefer-destructuring
ret = (
await this.api.baseMessageToContent(message, undefined, true)
)[0];
} else {
throw new Error(
"System messages are only permitted as the first passed message."
);
}
}
}

return ret;
}
}

Expand All @@ -172,7 +123,7 @@ export interface ChatGoogleBaseInput<AuthOptions>
GoogleConnectionParams<AuthOptions>,
GoogleAIModelParams,
GoogleAISafetyParams,
GeminiAPIConfig,
GoogleAIAPIParams,
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

/**
Expand Down Expand Up @@ -341,13 +292,14 @@ export abstract class ChatGoogleBase<AuthOptions>
const response = await this.connection.request(
messages,
parameters,
options
options,
runManager
);
const ret = this.connection.api.safeResponseToChatResult(
response,
this.safetyHandler
);
await runManager?.handleLLMNewToken(ret.generations[0].text);
const ret = this.connection.api.responseToChatResult(response);
const chunk = ret?.generations?.[0];
if (chunk) {
await runManager?.handleLLMNewToken(chunk.text || "");
}
return ret;
}

Expand All @@ -361,7 +313,8 @@ export abstract class ChatGoogleBase<AuthOptions>
const response = await this.streamedConnection.request(
_messages,
parameters,
options
options,
runManager
);

// Get the streaming parser of the response
Expand All @@ -372,6 +325,12 @@ export abstract class ChatGoogleBase<AuthOptions>
// that is either available or added to the queue
while (!stream.streamDone) {
const output = await stream.nextChunk();
await runManager?.handleCustomEvent(
`google-chunk-${this.constructor.name}`,
{
output,
}
);
if (
output &&
output.usageMetadata &&
Expand All @@ -386,10 +345,7 @@ export abstract class ChatGoogleBase<AuthOptions>
}
const chunk =
output !== null
? this.connection.api.safeResponseToChatGeneration(
{ data: output },
this.safetyHandler
)
? this.connection.api.responseToChatGeneration({ data: output })
: new ChatGenerationChunk({
text: "",
generationInfo: { finishReason: "stop" },
Expand All @@ -398,8 +354,17 @@ export abstract class ChatGoogleBase<AuthOptions>
usage_metadata: usageMetadata,
}),
});
yield chunk;
await runManager?.handleLLMNewToken(chunk.text);
if (chunk) {
yield chunk;
await runManager?.handleLLMNewToken(
chunk.text ?? "",
undefined,
undefined,
undefined,
undefined,
{ chunk }
);
}
}
}

Expand Down
Loading

0 comments on commit a1530da

Please sign in to comment.