feat(playground): parse model name and infer provider from span (#5021)

* feat(playground): parse model name and infer provider from span

* update azure model selector to be a text field to account for user-defined deployment names

* add label

* move defaults into generative constants, fallback to openai not azure

* update default in test
Parker-Stafford authored Oct 16, 2024
1 parent 9d375e5 commit 45973b7
Showing 8 changed files with 263 additions and 34 deletions.
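As a rough sketch of the flow the commit message describes (the wrapper function below is hypothetical and only stitches together the exported pieces visible elsewhere in this diff; the real logic lives in the un-expanded playgroundUtils.ts changes), model handling during span parsing might look like:

```ts
import { MODEL_NAME_PARSING_ERROR } from "./constants";
import { getModelProviderFromModelName } from "./playgroundUtils";

// Illustrative sketch only: derive a playground instance's model block from
// already-parsed span attributes, recording a parsing error when
// llm.model_name is absent. Import paths assume the snippet sits next to
// playgroundUtils.ts.
function getModelConfigFromAttributes(attributes: {
  llm?: { model_name?: unknown };
}): {
  model: { provider: ModelProvider; modelName: string } | null;
  parsingErrors: string[];
} {
  const modelName = attributes.llm?.model_name;
  if (typeof modelName === "string") {
    return {
      model: { provider: getModelProviderFromModelName(modelName), modelName },
      parsingErrors: [],
    };
  }
  return { model: null, parsingErrors: [MODEL_NAME_PARSING_ERROR] };
}
```

Both branches line up with the expectations in playgroundUtils.test.ts below: a present model_name yields an inferred provider, and a missing one adds MODEL_NAME_PARSING_ERROR to the parsing errors.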
7 changes: 7 additions & 0 deletions app/src/constants/generativeConstants.ts
@@ -6,3 +6,10 @@ export const ModelProviders: Record<ModelProvider, string> = {
AZURE_OPENAI: "Azure OpenAI",
ANTHROPIC: "Anthropic",
};

/**
* The default model provider
*/
export const DEFAULT_MODEL_PROVIDER: ModelProvider = "OPENAI";

export const DEFAULT_CHAT_ROLE: ChatMessageRole = "user";
44 changes: 30 additions & 14 deletions app/src/pages/playground/ModelConfigButton.tsx
@@ -3,6 +3,7 @@ import React, {
ReactNode,
startTransition,
Suspense,
useCallback,
useState,
} from "react";
import { graphql, useLazyLoadQuery } from "react-relay";
@@ -14,6 +15,7 @@ import {
Flex,
Form,
Text,
TextField,
View,
} from "@arizeai/components";

@@ -95,6 +97,20 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
`,
{ providerKey: instance.model.provider }
);

const onModelNameChange = useCallback(
(modelName: string) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
provider: instance.model.provider,
modelName,
},
});
},
[instance.model.provider, playgroundInstanceId, updateModel]
);

return (
<View padding="size-200">
<Form>
@@ -111,20 +127,20 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
});
}}
/>
<ModelPicker
modelName={instance.model.modelName}
provider={instance.model.provider}
query={query}
onChange={(modelName) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
provider: instance.model.provider,
modelName,
},
});
}}
/>
{instance.model.provider === "AZURE_OPENAI" ? (
<TextField
label="Deployment Name"
value={instance.model.modelName ?? ""}
onChange={onModelNameChange}
/>
) : (
<ModelPicker
modelName={instance.model.modelName}
provider={instance.model.provider}
query={query}
onChange={onModelNameChange}
/>
)}
</Form>
</View>
);
121 changes: 119 additions & 2 deletions app/src/pages/playground/__tests__/playgroundUtils.test.ts
@@ -1,15 +1,20 @@
import { DEFAULT_MODEL_PROVIDER } from "@phoenix/constants/generativeConstants";
import {
_resetInstanceId,
_resetMessageId,
PlaygroundInstance,
} from "@phoenix/store";

import {
getChatRole,
INPUT_MESSAGES_PARSING_ERROR,
MODEL_NAME_PARSING_ERROR,
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
SPAN_ATTRIBUTES_PARSING_ERROR,
} from "../constants";
import {
getChatRole,
getModelProviderFromModelName,
transformSpanAttributesToPlaygroundInstance,
} from "../playgroundUtils";

@@ -24,7 +29,7 @@ const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
isRunning: false,
model: {
provider: "OPENAI",
modelName: "gpt-4o",
modelName: "gpt-3.5-turbo",
},
input: { variableKeys: [], variablesValueCache: {} },
tools: [],
@@ -70,6 +75,10 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
expect(transformSpanAttributesToPlaygroundInstance(span)).toStrictEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
provider: "OPENAI",
modelName: "gpt-4o",
},
template: defaultTemplate,
output: undefined,
},
@@ -85,6 +94,10 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
expect(transformSpanAttributesToPlaygroundInstance(span)).toStrictEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
provider: "OPENAI",
modelName: "gpt-4o",
},
template: defaultTemplate,

output: undefined,
@@ -93,6 +106,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
INPUT_MESSAGES_PARSING_ERROR,
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
MODEL_NAME_PARSING_ERROR,
],
});
});
@@ -138,6 +152,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,

output: "This is an AI Answer",
},
parsingErrors: [OUTPUT_MESSAGES_PARSING_ERROR],
@@ -160,6 +175,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
...basePlaygroundSpan,
attributes: JSON.stringify({
llm: {
model_name: "gpt-4o",
input_messages: [
{
message: {
@@ -182,6 +198,10 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
provider: "OPENAI",
modelName: "gpt-4o",
},
template: {
__type: "chat",
messages: [
@@ -197,6 +217,84 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
parsingErrors: [],
});
});

it("should correctly parse the model name and infer the provider", () => {
const openAiAttributes = JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
model_name: "gpt-3.5-turbo",
},
});
const anthropicAttributes = JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
model_name: "claude-3-5-sonnet-20240620",
},
});
const unknownAttributes = JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
model_name: "test-my-deployment",
},
});

expect(
transformSpanAttributesToPlaygroundInstance({
...basePlaygroundSpan,
attributes: openAiAttributes,
})
).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
provider: "OPENAI",
modelName: "gpt-3.5-turbo",
},
},
parsingErrors: [],
});

_resetMessageId();
_resetInstanceId();

expect(
transformSpanAttributesToPlaygroundInstance({
...basePlaygroundSpan,
attributes: anthropicAttributes,
})
).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
provider: "ANTHROPIC",
modelName: "claude-3-5-sonnet-20240620",
},
},
parsingErrors: [],
});

_resetMessageId();
_resetInstanceId();

expect(
transformSpanAttributesToPlaygroundInstance({
...basePlaygroundSpan,
attributes: unknownAttributes,
})
).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
provider: DEFAULT_MODEL_PROVIDER,
modelName: "test-my-deployment",
},
},
parsingErrors: [],
});
});
});

describe("getChatRole", () => {
Expand All @@ -215,3 +313,22 @@ describe("getChatRole", () => {
expect(getChatRole("invalid")).toEqual("user");
});
});

describe("getModelProviderFromModelName", () => {
it("should return OPENAI if the model name includes 'gpt' or 'o1'", () => {
expect(getModelProviderFromModelName("gpt-3.5-turbo")).toEqual("OPENAI");
expect(getModelProviderFromModelName("o1")).toEqual("OPENAI");
});

it("should return ANTHROPIC if the model name includes 'claude'", () => {
expect(getModelProviderFromModelName("claude-3-5-sonnet-20240620")).toEqual(
"ANTHROPIC"
);
});

it(`should return ${DEFAULT_MODEL_PROVIDER} if the model name does not match any known models`, () => {
expect(getModelProviderFromModelName("test-my-model")).toEqual(
DEFAULT_MODEL_PROVIDER
);
});
});
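getChatRole itself is defined in playgroundUtils.ts, which is not expanded in this view. A minimal sketch consistent with these tests, assuming it canonicalizes role strings via the ChatRoleMap and DEFAULT_CHAT_ROLE shown in the constants below, could look like:

```ts
import { DEFAULT_CHAT_ROLE } from "@phoenix/constants/generativeConstants";

import { ChatRoleMap } from "./constants";

// Illustrative sketch only: map an arbitrary span role string onto a canonical
// ChatMessageRole, falling back to DEFAULT_CHAT_ROLE ("user") when the role is
// not recognized.
export function getChatRole(role: string): ChatMessageRole {
  const normalized = role.toLowerCase();
  for (const [chatRole, aliases] of Object.entries(ChatRoleMap)) {
    if (aliases.includes(normalized)) {
      return chatRole as ChatMessageRole;
    }
  }
  return DEFAULT_CHAT_ROLE;
}
```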
22 changes: 20 additions & 2 deletions app/src/pages/playground/constants.tsx
@@ -1,7 +1,5 @@
export const NUM_MAX_PLAYGROUND_INSTANCES = 4;

export const DEFAULT_CHAT_ROLE = "user";

/**
* Map of {@link ChatMessageRole} to potential role values.
* Used to map roles to a canonical role.
@@ -12,3 +10,23 @@ export const ChatRoleMap: Record<ChatMessageRole, string[]> = {
system: ["system"],
tool: ["tool"],
};

/**
* Parsing errors for parsing a span to a playground instance
*/
export const INPUT_MESSAGES_PARSING_ERROR =
"Unable to parse span input messages, expected messages which include a role and content.";
export const OUTPUT_MESSAGES_PARSING_ERROR =
"Unable to parse span output messages, expected messages which include a role and content.";
export const OUTPUT_VALUE_PARSING_ERROR =
"Unable to parse span output expected output.value to be present.";
export const SPAN_ATTRIBUTES_PARSING_ERROR =
"Unable to parse span attributes, attributes must be valid JSON.";
export const MODEL_NAME_PARSING_ERROR =
"Unable to parse model name, expected llm.model_name to be present.";

export const modelProviderToModelPrefixMap: Record<ModelProvider, string[]> = {
AZURE_OPENAI: [],
ANTHROPIC: ["claude"],
OPENAI: ["gpt", "o1"],
};
(The remaining changed files in this commit were not expanded in this view.)
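Among the un-expanded files is playgroundUtils.ts, where getModelProviderFromModelName presumably lives. The tests above and the modelProviderToModelPrefixMap in constants.tsx pin down its expected behavior; a minimal sketch consistent with both might be:

```ts
import { DEFAULT_MODEL_PROVIDER } from "@phoenix/constants/generativeConstants";

import { modelProviderToModelPrefixMap } from "./constants";

// Illustrative sketch only: infer the provider from a model (or deployment)
// name by substring match against the known prefixes, falling back to the
// default provider (OPENAI) when nothing matches.
export function getModelProviderFromModelName(
  modelName: string
): ModelProvider {
  for (const [provider, prefixes] of Object.entries(
    modelProviderToModelPrefixMap
  )) {
    if (prefixes.some((prefix) => modelName.includes(prefix))) {
      return provider as ModelProvider;
    }
  }
  return DEFAULT_MODEL_PROVIDER;
}
```

This matches the three test cases above: "gpt-3.5-turbo" maps to OPENAI, "claude-3-5-sonnet-20240620" to ANTHROPIC, and an unrecognized deployment name such as "test-my-deployment" falls back to DEFAULT_MODEL_PROVIDER.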
