Skip to content

Commit

Permalink
feat(playground): plumb through and apply template variables (#5052)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Oct 17, 2024
1 parent 95748e0 commit d0b1641
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 28 deletions.
11 changes: 11 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ input ChatCompletionInput {
model: GenerativeModelInput!
invocationParameters: InvocationParameters!
tools: [JSON!]
template: TemplateOptions
apiKey: String = null
}

Expand Down Expand Up @@ -1459,6 +1460,16 @@ type SystemApiKey implements ApiKey & Node {
id: GlobalID!
}

"""
The template syntax used to interpolate variables into prompt messages.
"""
enum TemplateLanguage {
  MUSTACHE
  F_STRING
}

"""
Options controlling how prompt templates are rendered before a chat
completion is requested.
"""
input TemplateOptions {
  """
  A JSON mapping of template variable names to their substitution values.
  """
  variables: JSON!

  """
  The template syntax to apply (mustache or f-string).
  """
  language: TemplateLanguage!
}

"""
A chunk of text content streamed back from a chat completion.
"""
type TextChunk {
  content: String!
}
Expand Down
4 changes: 2 additions & 2 deletions app/src/components/templateEditor/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
* ```
*/
/**
 * The supported template languages for the playground prompt editor.
 *
 * Values intentionally match the GraphQL `TemplateLanguage` enum names so
 * they can be sent to the server without translation.
 */
export const TemplateLanguages = {
  FString: "F_STRING", // {variable}
  Mustache: "MUSTACHE", // {{variable}}
} as const;
4 changes: 2 additions & 2 deletions app/src/pages/playground/PlaygroundInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ export function PlaygroundInput() {
if (variableKeys.length === 0) {
let templateSyntax = "";
switch (templateLanguage) {
case "f-string": {
case "F_STRING": {
templateSyntax = "{input name}";
break;
}
case "mustache": {
case "MUSTACHE": {
templateSyntax = "{{input name}}";
break;
}
Expand Down
16 changes: 15 additions & 1 deletion app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import { useCredentialsContext } from "@phoenix/contexts/CredentialsContext";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { useChatMessageStyles } from "@phoenix/hooks/useChatMessageStyles";
import type { ToolCall } from "@phoenix/store";
import { ChatMessage, generateMessageId } from "@phoenix/store";
import {
ChatMessage,
generateMessageId,
selectDerivedInputVariables,
} from "@phoenix/store";
import { assertUnreachable } from "@phoenix/typeUtils";

import {
Expand Down Expand Up @@ -135,6 +139,7 @@ function useChatCompletionSubscription({
$model: GenerativeModelInput!
$invocationParameters: InvocationParameters!
$tools: [JSON!]
$templateOptions: TemplateOptions
$apiKey: String
) {
chatCompletion(
Expand All @@ -143,6 +148,7 @@ function useChatCompletionSubscription({
model: $model
invocationParameters: $invocationParameters
tools: $tools
template: $templateOptions
apiKey: $apiKey
}
) {
Expand Down Expand Up @@ -212,6 +218,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
const instance = instances.find(
(instance) => instance.id === props.playgroundInstanceId
);
const templateLanguage = usePlaygroundContext(
(state) => state.templateLanguage
);
const templateVariables = usePlaygroundContext(selectDerivedInputVariables);
const markPlaygroundInstanceComplete = usePlaygroundContext(
(state) => state.markPlaygroundInstanceComplete
);
Expand Down Expand Up @@ -239,6 +249,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
invocationParameters: {
toolChoice: instance.toolChoice,
},
templateOptions: {
variables: templateVariables,
language: templateLanguage,
},
tools: instance.tools.map((tool) => tool.definition),
apiKey: credentials[instance.model.provider],
},
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 71 additions & 15 deletions src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from dataclasses import fields
from datetime import datetime
from enum import Enum
from itertools import chain
from typing import (
TYPE_CHECKING,
Expand All @@ -10,6 +11,7 @@
AsyncIterator,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Expand Down Expand Up @@ -46,6 +48,11 @@
from phoenix.server.dml_event import SpanInsertEvent
from phoenix.trace.attributes import unflatten
from phoenix.utilities.json import jsonify
from phoenix.utilities.template_formatters import (
FStringTemplateFormatter,
MustacheTemplateFormatter,
TemplateFormatter,
)

if TYPE_CHECKING:
from openai.types.chat import (
Expand All @@ -57,6 +64,18 @@
ToolCallIndex: TypeAlias = int


@strawberry.enum
class TemplateLanguage(Enum):
    """
    The supported prompt-template syntaxes.

    Member values mirror the names exposed by the GraphQL
    ``TemplateLanguage`` enum.
    """

    MUSTACHE = "MUSTACHE"
    F_STRING = "F_STRING"


@strawberry.input
class TemplateOptions:
    """
    GraphQL input describing how prompt templates should be rendered.
    """

    # JSON mapping of template variable names to substitution values.
    variables: JSONScalarType
    # Which template syntax to apply when formatting.
    language: TemplateLanguage


@strawberry.type
class TextChunk:
    """
    A piece of text content emitted while streaming a chat completion.
    """

    # The text payload of this chunk.
    content: str
Expand Down Expand Up @@ -91,42 +110,43 @@ class ChatCompletionInput:
model: GenerativeModelInput
invocation_parameters: InvocationParameters
tools: Optional[List[JSONScalarType]] = UNSET
template: Optional[TemplateOptions] = UNSET
api_key: Optional[str] = strawberry.field(default=None)


def to_openai_chat_completion_param(
    role: ChatCompletionMessageRole, content: JSONScalarType
) -> "ChatCompletionMessageParam":
    """
    Converts a (role, content) pair into the corresponding OpenAI
    chat-completion message param.

    Args:
        role: The playground message role to translate.
        content: The (already template-formatted) message content.

    Raises:
        NotImplementedError: For TOOL messages, which are not yet supported.
    """
    # Function-scope import, matching the file's existing pattern
    # (presumably to defer loading the openai package — confirm).
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    )

    if role is ChatCompletionMessageRole.USER:
        return ChatCompletionUserMessageParam(
            {
                "content": content,
                "role": "user",
            }
        )
    if role is ChatCompletionMessageRole.SYSTEM:
        return ChatCompletionSystemMessageParam(
            {
                "content": content,
                "role": "system",
            }
        )
    if role is ChatCompletionMessageRole.AI:
        return ChatCompletionAssistantMessageParam(
            {
                "content": content,
                "role": "assistant",
            }
        )
    if role is ChatCompletionMessageRole.TOOL:
        raise NotImplementedError
    # Exhaustiveness guard: fails loudly if a new role is added unhandled.
    assert_never(role)


@strawberry.type
Expand All @@ -140,6 +160,13 @@ async def chat_completion(
client = AsyncOpenAI(api_key=input.api_key)
invocation_parameters = jsonify(input.invocation_parameters)

messages: List[Tuple[ChatCompletionMessageRole, str]] = [
(message.role, message.content) for message in input.messages
]
if template_options := input.template:
messages = list(_formatted_messages(messages, template_options))
openai_messages = [to_openai_chat_completion_param(*message) for message in messages]

in_memory_span_exporter = InMemorySpanExporter()
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(
Expand All @@ -154,7 +181,7 @@ async def chat_completion(
_llm_span_kind(),
_llm_model_name(input.model.name),
_llm_tools(input.tools or []),
_llm_input_messages(input.messages),
_llm_input_messages(messages),
_llm_invocation_parameters(invocation_parameters),
_input_value_and_mime_type(input),
)
Expand All @@ -165,7 +192,7 @@ async def chat_completion(
tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list)
role: Optional[str] = None
async for chunk in await client.chat.completions.create(
messages=(to_openai_chat_completion_param(message) for message in input.messages),
messages=openai_messages,
model=input.model.name,
stream=True,
tools=input.tools or NOT_GIVEN,
Expand Down Expand Up @@ -291,10 +318,12 @@ def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))


def _llm_input_messages(
    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
) -> Iterator[Tuple[str, Any]]:
    """
    Yields flattened OpenInference span attributes for the input messages.

    For each message, emits one role attribute (lower-cased enum value) and
    one content attribute, keyed under the LLM_INPUT_MESSAGES prefix with
    the message's positional index.
    """
    for i, (role, content) in enumerate(messages):
        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content


def _llm_output_messages(
Expand Down Expand Up @@ -332,6 +361,33 @@ def _datetime(*, epoch_nanoseconds: float) -> datetime:
return datetime.fromtimestamp(epoch_seconds)


def _formatted_messages(
    messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions
) -> Iterator[Tuple[ChatCompletionMessageRole, str]]:
    """
    Formats the message contents using the given template options.

    Each message's content is treated as a template and rendered with the
    variables from ``template_options``; roles pass through unchanged.

    Note: the previous implementation unpacked ``zip(*messages)``, which
    raises ``ValueError`` when ``messages`` is empty; iterating directly
    yields nothing for empty input instead.
    """
    template_formatter = _template_formatter(template_language=template_options.language)
    for role, template in messages:
        yield role, template_formatter.format(template, **template_options.variables)


def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """
    Instantiates the appropriate template formatter for the template language.
    """
    # Dispatch table from language to formatter class.
    formatter_types: Dict[TemplateLanguage, type] = {
        TemplateLanguage.MUSTACHE: MustacheTemplateFormatter,
        TemplateLanguage.F_STRING: FStringTemplateFormatter,
    }
    formatter_type = formatter_types.get(template_language)
    if formatter_type is None:
        # Unreachable for valid enum members; guards against new, unhandled
        # languages being added to the enum.
        assert_never(template_language)
    return formatter_type()


JSON = OpenInferenceMimeTypeValues.JSON.value

LLM = OpenInferenceSpanKindValues.LLM.value
Expand Down
Loading

0 comments on commit d0b1641

Please sign in to comment.