Skip to content

Commit

Permalink
fix: Hide tool button on non-supported providers, Display errors for …
Browse files Browse the repository at this point in the history
…missing invocation params (#5470)

* fix: Hide tool button on providers that do not have tool support implemented

* refactor: Hoist invocation param fetching to top level of instance, cache results in store

Additionally moves invocation param filtering to just before making the completion queries,
not when the invocation param form is opened

* feat: Display error messages and tooltips when required invocation param fields are not provided

* fix test

* Remove redundant gql query in PlaygroundChatTemplateFooter

* lint

* Add comments

* Merge default model parameter values with playground instance values

* Use danger color alias for invocation param field errors

* Rename variable for clarity

* Simplify invocation parameter merging logic

* Add comments describing reasoning for query field name aliases
  • Loading branch information
cephalization authored Nov 25, 2024
1 parent 326830c commit baecc54
Show file tree
Hide file tree
Showing 15 changed files with 617 additions and 292 deletions.
11 changes: 7 additions & 4 deletions app/src/components/Loading.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import React from "react";
import React, { ComponentProps } from "react";
import { css } from "@emotion/react";

import { ProgressCircle, Text } from "@arizeai/components";

type LoadingProps = { message?: string };
export const Loading = ({ message }: LoadingProps) => {
type LoadingProps = {
message?: string;
size?: ComponentProps<typeof ProgressCircle>["size"];
};
export const Loading = ({ message, size }: LoadingProps) => {
return (
<div
css={css`
Expand All @@ -17,7 +20,7 @@ export const Loading = ({ message }: LoadingProps) => {
gap: var(--px-spacing-med);
`}
>
<ProgressCircle isIndeterminate aria-label="loading" />
<ProgressCircle isIndeterminate aria-label="loading" size={size} />
{message != null ? <Text>{message}</Text> : null}
</div>
);
Expand Down
36 changes: 7 additions & 29 deletions app/src/pages/playground/InvocationParametersFormFields.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useCallback, useEffect } from "react";
import React, { useCallback } from "react";
import { graphql, useLazyLoadQuery } from "react-relay";

import { Slider, Switch, TextField } from "@arizeai/components";
Expand All @@ -13,11 +13,7 @@ import {
import { InvocationParameterInput } from "./__generated__/PlaygroundOutputSubscription.graphql";
import { paramsToIgnoreInInvocationParametersForm } from "./constants";
import { InvocationParameterJsonEditor } from "./InvocationParameterJsonEditor";
import {
areInvocationParamsEqual,
constrainInvocationParameterInputsToDefinition,
toCamelCase,
} from "./playgroundUtils";
import { areInvocationParamsEqual, toCamelCase } from "./playgroundUtils";

export type InvocationParameter = Mutable<
InvocationParametersFormFieldsQuery$data["modelInvocationParameters"]
Expand Down Expand Up @@ -205,9 +201,7 @@ export const InvocationParametersFormFields = ({
const updateInstanceModelInvocationParameters = usePlaygroundContext(
(state) => state.updateInstanceModelInvocationParameters
);
const filterInstanceModelInvocationParameters = usePlaygroundContext(
(state) => state.filterInstanceModelInvocationParameters
);

/**
* Azure openai has user defined model names but our invocation parameters query will never know
* what they are. We will just pass in an empty model name and the query will fallback to the set
Expand All @@ -227,6 +221,10 @@ export const InvocationParametersFormFields = ({
required
canonicalName
}
# defaultValue must be aliased because Relay will not create a union type for fields with the same name
# follow the naming convention of the field type e.g. floatDefaultValue for FloatInvocationParameter
# default value mapping elsewhere in playground code relies on this naming convention
# https://github.com/facebook/relay/issues/3776
... on BoundedFloatInvocationParameter {
minValue
maxValue
Expand Down Expand Up @@ -306,26 +304,6 @@ export const InvocationParametersFormFields = ({
[instance, updateInstanceModelInvocationParameters]
);

useEffect(() => {
// filter invocation parameters to only include those that are supported by the model
// This will remove configured values that are not supported by the newly selected model
// Including invocation parameters managed outside of this form, like response_format
if (modelInvocationParameters) {
filterInstanceModelInvocationParameters({
instanceId: instance.id,
modelSupportedInvocationParameters:
modelInvocationParameters as Mutable<
typeof modelInvocationParameters
>,
filter: constrainInvocationParameterInputsToDefinition,
});
}
}, [
filterInstanceModelInvocationParameters,
instance.id,
modelInvocationParameters,
]);

// It is safe to render this component if the model name is not set for non-azure models
// Hooks will still run to filter invocation parameters to only include those supported by the model
// but no form fields will be rendered if the model name is not set
Expand Down
41 changes: 41 additions & 0 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
TextField,
Tooltip,
TooltipTrigger,
TriggerWrap,
} from "@arizeai/components";

import {
Expand All @@ -40,6 +41,7 @@ import { InvocationParametersFormFields } from "./InvocationParametersFormFields
import { ModelPicker } from "./ModelPicker";
import { ModelProviderPicker } from "./ModelProviderPicker";
import {
areRequiredInvocationParametersConfigured,
convertInstanceToolsToProvider,
convertMessageToolCallsToProvider,
} from "./playgroundUtils";
Expand Down Expand Up @@ -165,6 +167,16 @@ export function ModelConfigButton(props: ModelConfigButtonProps) {
`Playground instance ${props.playgroundInstanceId} not found`
);
}

const modelSupportedInvocationParameters =
instance.model.supportedInvocationParameters;
const configuredInvocationParameters = instance.model.invocationParameters;
const requiredInvocationParametersConfigured =
areRequiredInvocationParametersConfigured(
configuredInvocationParameters,
modelSupportedInvocationParameters
);

return (
<Fragment>
<Button
Expand All @@ -191,6 +203,18 @@ export function ModelConfigButton(props: ModelConfigButtonProps) {
>
<Text>{instance.model.modelName || "--"}</Text>
</div>
{!requiredInvocationParametersConfigured ? (
<TooltipTrigger delay={0} offset={5}>
<span>
<TriggerWrap>
<Icon color="danger" svg={<Icons.InfoOutline />} />
</TriggerWrap>
</span>
<Tooltip>
Some required invocation parameters are not configured.
</Tooltip>
</TooltipTrigger>
) : null}
</Flex>
</Button>
<DialogContainer
Expand Down Expand Up @@ -282,6 +306,15 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
const updateInstance = usePlaygroundContext((state) => state.updateInstance);
const updateModel = usePlaygroundContext((state) => state.updateModel);

const modelSupportedInvocationParameters =
instance.model.supportedInvocationParameters;
const configuredInvocationParameters = instance.model.invocationParameters;
const requiredInvocationParametersConfigured =
areRequiredInvocationParametersConfigured(
configuredInvocationParameters,
modelSupportedInvocationParameters
);

const query = useLazyLoadQuery<ModelConfigButtonDialogQuery>(
graphql`
query ModelConfigButtonDialogQuery($providerKey: GenerativeProviderKey!) {
Expand Down Expand Up @@ -365,6 +398,14 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {

return (
<form css={modelConfigFormCSS}>
{!requiredInvocationParametersConfigured ? (
<Flex direction="row" gap="size-100">
<Icon color="danger" svg={<Icons.InfoOutline />} />
<Text color="danger">
Some required invocation parameters are not configured.
</Text>
</Flex>
) : null}
<ModelProviderPicker
provider={instance.model.provider}
query={query}
Expand Down
97 changes: 97 additions & 0 deletions app/src/pages/playground/ModelSupportedParamsFetcher.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import { useEffect } from "react";
import { graphql, useLazyLoadQuery } from "react-relay";

import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { Mutable } from "@phoenix/typeUtils";

import { ModelSupportedParamsFetcherQuery } from "./__generated__/ModelSupportedParamsFetcherQuery.graphql";

/**
* Fetches the supported invocation parameters for a model and syncs them to the
* playground store instance.
*/
export const ModelSupportedParamsFetcher = ({
instanceId,
}: {
instanceId: number;
}) => {
const modelProvider = usePlaygroundContext(
(state) =>
state.instances.find((instance) => instance.id === instanceId)?.model
.provider
);
const modelName = usePlaygroundContext(
(state) =>
state.instances.find((instance) => instance.id === instanceId)?.model
.modelName
);
const updateModelSupportedInvocationParameters = usePlaygroundContext(
(state) => state.updateModelSupportedInvocationParameters
);
const { modelInvocationParameters } =
useLazyLoadQuery<ModelSupportedParamsFetcherQuery>(
graphql`
query ModelSupportedParamsFetcherQuery($input: ModelsInput!) {
modelInvocationParameters(input: $input) {
__typename
... on InvocationParameterBase {
invocationName
canonicalName
required
}
# defaultValue must be aliased because Relay will not create a union type for fields with the same name
# follow the naming convention of the field type e.g. floatDefaultValue for FloatInvocationParameter
# default value mapping elsewhere in playground code relies on this naming convention
# https://github.com/facebook/relay/issues/3776
... on BooleanInvocationParameter {
booleanDefaultValue: defaultValue
invocationInputField
}
... on BoundedFloatInvocationParameter {
floatDefaultValue: defaultValue
invocationInputField
}
... on FloatInvocationParameter {
floatDefaultValue: defaultValue
invocationInputField
}
... on IntInvocationParameter {
intDefaultValue: defaultValue
invocationInputField
}
... on JSONInvocationParameter {
jsonDefaultValue: defaultValue
invocationInputField
}
... on StringInvocationParameter {
stringDefaultValue: defaultValue
invocationInputField
}
... on StringListInvocationParameter {
stringListDefaultValue: defaultValue
invocationInputField
}
}
}
`,
{
input: {
providerKey: modelProvider,
modelName,
},
}
);
useEffect(() => {
updateModelSupportedInvocationParameters({
instanceId,
supportedInvocationParameters: modelInvocationParameters as Mutable<
typeof modelInvocationParameters
>,
});
}, [
modelInvocationParameters,
instanceId,
updateModelSupportedInvocationParameters,
]);
return null;
};
2 changes: 1 addition & 1 deletion app/src/pages/playground/Playground.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ function PlaygroundContent() {
<Flex direction="row" gap="size-200" maxWidth="100%">
{instances.map((instance) => (
<View
flex="1 1 0px"
key={instance.id}
minWidth={PLAYGROUND_PROMPT_PANEL_MIN_WIDTH}
flex="1 1 0px"
>
<PlaygroundTemplate
key={instance.id}
Expand Down
35 changes: 16 additions & 19 deletions app/src/pages/playground/PlaygroundChatTemplate.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import React, {
PropsWithChildren,
Suspense,
useCallback,
useState,
} from "react";
import React, { PropsWithChildren, useCallback, useState } from "react";
import {
DndContext,
KeyboardSensor,
Expand Down Expand Up @@ -48,20 +43,21 @@ import { assertUnreachable } from "@phoenix/typeUtils";
import { safelyParseJSON } from "@phoenix/utils/jsonUtils";

import { ChatMessageToolCallsEditor } from "./ChatMessageToolCallsEditor";
import { RESPONSE_FORMAT_PARAM_CANONICAL_NAME } from "./constants";
import {
RESPONSE_FORMAT_PARAM_CANONICAL_NAME,
RESPONSE_FORMAT_PARAM_NAME,
} from "./constants";
import {
AIMessageContentRadioGroup,
AIMessageMode,
MessageMode,
} from "./MessageContentRadioGroup";
import { MessageRolePicker } from "./MessageRolePicker";
import {
PlaygroundChatTemplateFooter,
PlaygroundChatTemplateFooterFallback,
} from "./PlaygroundChatTemplateFooter";
import { PlaygroundChatTemplateFooter } from "./PlaygroundChatTemplateFooter";
import { PlaygroundResponseFormat } from "./PlaygroundResponseFormat";
import { PlaygroundTools } from "./PlaygroundTools";
import {
areInvocationParamsEqual,
convertMessageToolCallsToProvider,
createToolCallForProvider,
normalizeMessageAttributeValue,
Expand Down Expand Up @@ -94,8 +90,11 @@ export function PlaygroundChatTemplate(props: PlaygroundChatTemplateProps) {

const hasTools = playgroundInstance.tools.length > 0;
const hasResponseFormat =
playgroundInstance.model.invocationParameters.find(
(p) => p.canonicalName === RESPONSE_FORMAT_PARAM_CANONICAL_NAME
playgroundInstance.model.invocationParameters.find((p) =>
areInvocationParamsEqual(p, {
canonicalName: RESPONSE_FORMAT_PARAM_CANONICAL_NAME,
invocationName: RESPONSE_FORMAT_PARAM_NAME,
})
) != null;
const { template } = playgroundInstance;
if (template.__type !== "chat") {
Expand Down Expand Up @@ -171,12 +170,10 @@ export function PlaygroundChatTemplate(props: PlaygroundChatTemplateProps) {
borderTopWidth="thin"
borderBottomWidth={hasTools || hasResponseFormat ? "thin" : undefined}
>
<Suspense fallback={<PlaygroundChatTemplateFooterFallback />}>
<PlaygroundChatTemplateFooter
instanceId={id}
hasResponseFormat={hasResponseFormat}
/>
</Suspense>
<PlaygroundChatTemplateFooter
instanceId={id}
hasResponseFormat={hasResponseFormat}
/>
</View>
{hasTools ? <PlaygroundTools {...props} /> : null}
{hasResponseFormat ? <PlaygroundResponseFormat {...props} /> : null}
Expand Down
Loading

0 comments on commit baecc54

Please sign in to comment.