fix(playground): plumb through and record invocation parameters (#5005)
axiomofjoy authored Oct 16, 2024
1 parent ab09109 commit 9d375e5
Showing 10 changed files with 110 additions and 46 deletions.
12 changes: 11 additions & 1 deletion app/schema.graphql
@@ -68,7 +68,8 @@ union Bin = NominalBin | IntervalBin | MissingValueBin
input ChatCompletionInput {
messages: [ChatCompletionMessageInput!]!
model: GenerativeModelInput!
-apiKey: String = null
+invocationParameters: InvocationParameters!
+apiKey: String
}

input ChatCompletionMessageInput {
@@ -895,6 +896,15 @@ type IntervalBin {
range: NumericRange!
}

+input InvocationParameters {
+temperature: Float
+maxCompletionTokens: Int
+maxTokens: Int
+topP: Float
+stop: [String!]
+seed: Int
+}

"""
The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf).
"""
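For orientation, here is a minimal Python sketch of the variables a chatCompletion operation can carry under the updated schema. The field names mirror ChatCompletionInput, GenerativeModelInput, and InvocationParameters above; the message shape, provider key, model name, and parameter values are illustrative assumptions, not values taken from this commit.

import json

# Illustrative GraphQL variables for the new ChatCompletionInput shape.
# The message fields and provider/model values are assumptions for the example;
# invocationParameters uses the fields added to the schema above.
variables = {
    "messages": [{"role": "USER", "content": "Hello!"}],  # assumed message fields
    "model": {"providerKey": "OPENAI", "name": "gpt-4o"},  # assumed provider/model
    "invocationParameters": {
        "temperature": 0.1,
        "maxTokens": 256,
        "topP": 1.0,
        "stop": ["\n\n"],
        "seed": 42,
    },
    "apiKey": None,  # still optional; only the `= null` default was dropped from the SDL
}
print(json.dumps(variables, indent=2))
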
11 changes: 10 additions & 1 deletion app/src/pages/playground/PlaygroundOutput.tsx
@@ -105,10 +105,16 @@ function useChatCompletionSubscription({
subscription PlaygroundOutputSubscription(
$messages: [ChatCompletionMessageInput!]!
$model: GenerativeModelInput!
+$invocationParameters: InvocationParameters!
$apiKey: String
) {
chatCompletion(
-input: { messages: $messages, model: $model, apiKey: $apiKey }
+input: {
+messages: $messages
+model: $model
+invocationParameters: $invocationParameters
+apiKey: $apiKey
+}
)
}
`,
@@ -187,6 +193,9 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
providerKey: instance.model.provider,
name: instance.model.modelName || "",
},
+invocationParameters: {
+temperature: 0.1, // TODO: add invocation parameters
+},
apiKey: credentials[instance.model.provider],
},
runId: instance.activeRunId,

Some generated files are not rendered by default.

@@ -20,7 +20,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"


-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
@@ -32,7 +32,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"


-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
2 changes: 1 addition & 1 deletion src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
@@ -20,7 +20,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"


-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
12 changes: 6 additions & 6 deletions src/phoenix/db/models.py
@@ -50,7 +50,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"


-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
@@ -271,7 +271,7 @@ class LatencyMs(expression.FunctionElement[float]):
name = "latency_ms"


-@compiles(LatencyMs) # type: ignore
+@compiles(LatencyMs)
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
start_time, end_time = list(element.clauses)
@@ -287,7 +287,7 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any:
)


-@compiles(LatencyMs, "sqlite") # type: ignore
+@compiles(LatencyMs, "sqlite")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
start_time, end_time = list(element.clauses)
@@ -308,21 +308,21 @@ class TextContains(expression.FunctionElement[str]):
name = "text_contains"


-@compiles(TextContains) # type: ignore
+@compiles(TextContains)
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
string, substring = list(element.clauses)
return compiler.process(string.contains(substring), **kw)


@compiles(TextContains, "postgresql") # type: ignore
@compiles(TextContains, "postgresql")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
string, substring = list(element.clauses)
return compiler.process(func.strpos(string, substring) > 0, **kw)


@compiles(TextContains, "sqlite") # type: ignore
@compiles(TextContains, "sqlite")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
string, substring = list(element.clauses)
18 changes: 18 additions & 0 deletions src/phoenix/server/api/input_types/InvocationParameters.py
@@ -0,0 +1,18 @@
from typing import List, Optional

import strawberry
from strawberry import UNSET


@strawberry.input
class InvocationParameters:
"""
Invocation parameters interface shared between different providers.
"""

temperature: Optional[float] = UNSET
max_completion_tokens: Optional[int] = UNSET
max_tokens: Optional[int] = UNSET
top_p: Optional[float] = UNSET
stop: Optional[List[str]] = UNSET
seed: Optional[int] = UNSET
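The fields default to strawberry's UNSET so the resolver can distinguish "not provided" from an explicit null. Below is a minimal sketch of the idea using a plain dataclass stand-in (hypothetical names, not Phoenix's jsonify): only parameters the caller actually set are forwarded as keyword arguments, so provider defaults apply to everything else.

from dataclasses import asdict, dataclass
from typing import List, Optional


@dataclass
class InvocationParametersSketch:  # hypothetical stand-in for the strawberry input above
    temperature: Optional[float] = None
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    stop: Optional[List[str]] = None
    seed: Optional[int] = None


def to_kwargs(params: InvocationParametersSketch) -> dict:
    # Drop unset fields so the provider's own defaults apply to anything left blank.
    return {name: value for name, value in asdict(params).items() if value is not None}


# to_kwargs(InvocationParametersSketch(temperature=0.1, max_tokens=256))
# -> {"temperature": 0.1, "max_tokens": 256}
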
45 changes: 20 additions & 25 deletions src/phoenix/server/api/subscriptions.py
@@ -1,12 +1,9 @@
-import json
-from dataclasses import asdict
from datetime import datetime
-from enum import Enum
from itertools import chain
-from json import JSONEncoder
-from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple

import strawberry
+from openinference.instrumentation import safe_json_dumps
from openinference.semconv.trace import (
MessageAttributes,
OpenInferenceMimeTypeValues,
@@ -17,18 +14,20 @@
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import StatusCode
-from pydantic import BaseModel
from sqlalchemy import insert, select
+from strawberry import UNSET
from strawberry.types import Info
from typing_extensions import assert_never

from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
+from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
from phoenix.server.dml_event import SpanInsertEvent
from phoenix.trace.attributes import unflatten
+from phoenix.utilities.json import jsonify

if TYPE_CHECKING:
from openai.types.chat import (
@@ -48,7 +47,8 @@ class GenerativeModelInput:
class ChatCompletionInput:
messages: List[ChatCompletionMessageInput]
model: GenerativeModelInput
-api_key: Optional[str] = None
+invocation_parameters: InvocationParameters
+api_key: Optional[str] = UNSET


def to_openai_chat_completion_param(
@@ -94,7 +94,9 @@ async def chat_completion(
) -> AsyncIterator[str]:
from openai import AsyncOpenAI

-client = AsyncOpenAI(api_key=input.api_key)
+api_key = input.api_key or None
+client = AsyncOpenAI(api_key=api_key)
+invocation_parameters = jsonify(input.invocation_parameters)

in_memory_span_exporter = InMemorySpanExporter()
tracer_provider = TracerProvider()
@@ -109,8 +111,9 @@
chain(
_llm_span_kind(),
_llm_model_name(input.model.name),
-_input_value_and_mime_type(input),
_llm_input_messages(input.messages),
+_llm_invocation_parameters(invocation_parameters),
+_input_value_and_mime_type(input)
)
),
) as span:
@@ -121,6 +124,7 @@
messages=(to_openai_chat_completion_param(message) for message in input.messages),
model=input.model.name,
stream=True,
+**invocation_parameters,
):
chunks.append(chunk)
choice = chunk.choices[0]
@@ -206,14 +210,18 @@ def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
yield LLM_MODEL_NAME, model_name


+def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
+yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)


def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
yield INPUT_MIME_TYPE, JSON
-yield INPUT_VALUE, json.dumps(asdict(input), cls=GraphQLInputJSONEncoder)
+yield INPUT_VALUE, safe_json_dumps(jsonify(input))


def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
yield OUTPUT_MIME_TYPE, JSON
-yield OUTPUT_VALUE, json.dumps(output, cls=ChatCompletionOutputJSONEncoder)
+yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))


def _llm_input_messages(messages: List[ChatCompletionMessageInput]) -> Iterator[Tuple[str, Any]]:
@@ -242,20 +250,6 @@ def _datetime(*, epoch_nanoseconds: float) -> datetime:
return datetime.fromtimestamp(epoch_seconds)


-class GraphQLInputJSONEncoder(JSONEncoder):
-def default(self, obj: Any) -> Any:
-if isinstance(obj, Enum):
-return obj.value
-return super().default(obj)


-class ChatCompletionOutputJSONEncoder(JSONEncoder):
-def default(self, obj: Any) -> Any:
-if isinstance(obj, BaseModel):
-return obj.model_dump()
-return super().default(obj)


JSON = OpenInferenceMimeTypeValues.JSON.value

LLM = OpenInferenceSpanKindValues.LLM.value
@@ -268,6 +262,7 @@ def default(self, obj: Any) -> Any:
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
+LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS

MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
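
Taken together, the subscription now threads the jsonify-ed parameter dict in two directions: splatted into the OpenAI streaming call and serialized onto the span as LLM_INVOCATION_PARAMETERS. Below is a hedged sketch of that flow with an assumed model name and parameter values; the helper is illustrative, not the Phoenix resolver.

from openai import AsyncOpenAI
from openinference.instrumentation import safe_json_dumps


async def stream_completion(api_key: str, invocation_parameters: dict) -> str:
    client = AsyncOpenAI(api_key=api_key)
    # Only keys present in the dict reach the provider, mirroring **invocation_parameters above;
    # anything left unset falls back to OpenAI's defaults.
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # assumed model name for illustration
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        **invocation_parameters,  # e.g. {"temperature": 0.1, "max_tokens": 256}
    )
    text = ""
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            text += chunk.choices[0].delta.content
    # The same dict is what gets recorded on the span, serialized once, e.g.:
    # span.set_attribute(LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters))
    print("invocation parameters recorded as:", safe_json_dumps(invocation_parameters))
    return text


# import asyncio; asyncio.run(stream_completion("sk-...", {"temperature": 0.1, "max_tokens": 256}))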