Skip to content

Commit

Permalink
fix(playground): plumb through model name and providers (#4999)
Browse files Browse the repository at this point in the history
Co-authored-by: Mikyo King <[email protected]>
  • Loading branch information
axiomofjoy and mikeldking authored Oct 15, 2024
1 parent 23e9d82 commit 23958bd
Show file tree
Hide file tree
Showing 14 changed files with 327 additions and 144 deletions.
25 changes: 21 additions & 4 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ union Bin = NominalBin | IntervalBin | MissingValueBin

input ChatCompletionInput {
messages: [ChatCompletionMessageInput!]!
model: GenerativeModelInput!
}

input ChatCompletionMessageInput {
Expand Down Expand Up @@ -822,6 +823,22 @@ type Functionality {
tracing: Boolean!
}

input GenerativeModelInput {
providerKey: GenerativeProviderKey!
name: String!
}

type GenerativeProvider {
name: String!
key: GenerativeProviderKey!
}

enum GenerativeProviderKey {
OPENAI
ANTHROPIC
AZURE_OPENAI
}

"""
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
"""
Expand Down Expand Up @@ -926,9 +943,8 @@ type Model {
): PerformanceTimeSeries!
}

type ModelProvider {
name: String!
modelNames: [String!]!
input ModelNamesInput {
providerKey: GenerativeProviderKey!
}

type Mutation {
Expand Down Expand Up @@ -1123,7 +1139,8 @@ type PromptResponse {
}

type Query {
modelProviders(vendors: [String!]!): [ModelProvider!]!
modelProviders: [GenerativeProvider!]!
modelNames(input: ModelNamesInput!): [String!]!
users(first: Int = 50, last: Int, after: String, before: String): UserConnection!
userRoles: [UserRole!]!
userApiKeys: [UserApiKey!]!
Expand Down
19 changes: 10 additions & 9 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ export function ModelConfigButton(props: ModelConfigButtonProps) {

interface ModelConfigDialogContentProps extends ModelConfigButtonProps {}
function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
const query = useLazyLoadQuery<ModelConfigButtonDialogQuery>(
graphql`
query ModelConfigButtonDialogQuery {
...ModelPickerFragment
}
`,
{}
);
const { playgroundInstanceId } = props;
const updateModel = usePlaygroundContext((state) => state.updateModel);
const instance = usePlaygroundContext((state) =>
Expand All @@ -94,12 +86,21 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
`Playground instance ${props.playgroundInstanceId} not found`
);
}

const query = useLazyLoadQuery<ModelConfigButtonDialogQuery>(
graphql`
query ModelConfigButtonDialogQuery($providerKey: GenerativeProviderKey!) {
...ModelProviderPickerFragment
...ModelPickerFragment @arguments(providerKey: $providerKey)
}
`,
{ providerKey: instance.model.provider }
);
return (
<View padding="size-200">
<Form>
<ModelProviderPicker
provider={instance.model.provider}
query={query}
onChange={(provider) => {
updateModel({
instanceId: playgroundInstanceId,
Expand Down
21 changes: 7 additions & 14 deletions app/src/pages/playground/ModelPicker.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useMemo } from "react";
import React from "react";
import { graphql, useFragment } from "react-relay";

import { Item, Picker, PickerProps } from "@arizeai/components";
Expand All @@ -18,22 +18,15 @@ type ModelPickerProps = {
export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
const data = useFragment<ModelPickerFragment$key>(
graphql`
fragment ModelPickerFragment on Query {
modelProviders(vendors: ["OpenAI", "Anthropic"]) {
name
modelNames
}
fragment ModelPickerFragment on Query
@argumentDefinitions(
providerKey: { type: "GenerativeProviderKey!", defaultValue: OPENAI }
) {
modelNames(input: { providerKey: $providerKey })
}
`,
query
);
const modelNames = useMemo(() => {
// TODO: Lowercase is not enough for things like Azure OpenAI
const provider = data.modelProviders.find(
(provider) => provider.name.toLowerCase() === props.provider.toLowerCase()
);
return provider?.modelNames ?? [];
}, [data, props.provider]);
return (
<Picker
label={"Model"}
Expand All @@ -49,7 +42,7 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
width={"100%"}
{...props}
>
{modelNames.map((modelName) => {
{data.modelNames.map((modelName) => {
return <Item key={modelName}>{modelName}</Item>;
})}
</Picker>
Expand Down
21 changes: 18 additions & 3 deletions app/src/pages/playground/ModelProviderPicker.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import React from "react";
import { graphql, useFragment } from "react-relay";

import { Item, Picker, PickerProps } from "@arizeai/components";

import { ModelProviders } from "@phoenix/constants/generativeConstants";
import { isModelProvider } from "@phoenix/utils/generativeUtils";

import type { ModelProviderPickerFragment$key } from "./__generated__/ModelProviderPickerFragment.graphql";

type ModelProviderPickerProps = {
onChange: (provider: ModelProvider) => void;
query: ModelProviderPickerFragment$key;
provider?: ModelProvider;
} & Omit<
PickerProps<ModelProvider>,
Expand All @@ -15,8 +18,20 @@ type ModelProviderPickerProps = {

export function ModelProviderPicker({
onChange,
query,
...props
}: ModelProviderPickerProps) {
const data = useFragment<ModelProviderPickerFragment$key>(
graphql`
fragment ModelProviderPickerFragment on Query {
modelProviders {
key
name
}
}
`,
query
);
return (
<Picker
label={"Provider"}
Expand All @@ -33,8 +48,8 @@ export function ModelProviderPicker({
width={"100%"}
{...props}
>
{Object.entries(ModelProviders).map(([key, value]) => {
return <Item key={key}>{value}</Item>;
{data.modelProviders.map((provider) => {
return <Item key={provider.key}>{provider.name}</Item>;
})}
</Picker>
);
Expand Down
2 changes: 1 addition & 1 deletion app/src/pages/playground/PlaygroundCredentialsDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export function PlaygroundCredentialsDropdown() {
<Heading level={2} weight="heavy">
API Keys
</Heading>
<Text color="white70">
<Text color="text-700">
API keys are stored in your browser and used to communicate with
their respective API&apos;s.
</Text>
Expand Down
7 changes: 6 additions & 1 deletion app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ function useChatCompletionSubscription({
subscription: graphql`
subscription PlaygroundOutputSubscription(
$messages: [ChatCompletionMessageInput!]!
$model: GenerativeModelInput!
) {
chatCompletion(input: { messages: $messages })
chatCompletion(input: { messages: $messages, model: $model })
}
`,
variables: params,
Expand Down Expand Up @@ -177,6 +178,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
useChatCompletionSubscription({
params: {
messages: instance.template.messages.map(toGqlChatCompletionMessage),
model: {
providerKey: instance.model.provider,
name: instance.model.modelName || "",
},
},
runId: instance.activeRunId,
onNext: (response) => {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 23958bd

Please sign in to comment.