Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update semantic model length validation for literal retrieval #72

Merged
merged 5 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions semantic_model_generator/data_processing/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class Column:
column_name: str
column_type: str
values: Optional[List[str]] = None
comment: Optional[str] = (
None # comment field's to save the column comment user specified on the column
)
comment: Optional[
str
] = None # comment field's to save the column comment user specified on the column

def __post_init__(self: Any) -> None:
"""
Expand All @@ -37,9 +37,9 @@ class Table:
id_: int
name: str
columns: List[Column]
comment: Optional[str] = (
None # comment field's to save the table comment user specified on the table
)
comment: Optional[
str
] = None # comment field's to save the table comment user specified on the table

def __post_init__(self: Any) -> None:
for col in self.columns:
Expand Down
2 changes: 2 additions & 0 deletions semantic_model_generator/protos/semantic_model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ message Dimension {
bool unique = 6 [(optional) = true];
// Sample values of this column.
repeated string sample_values = 7 [(optional) = true];
// Name of a Cortex Search Service configured on this column.
string cortex_search_service_name = 12 [(optional) = true];
}

// Time dimension columns contain time values (e.g. sale_date, created_at, year).
Expand Down
116 changes: 59 additions & 57 deletions semantic_model_generator/protos/semantic_model_pb2.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions semantic_model_generator/protos/semantic_model_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,24 @@ OPTIONAL_FIELD_NUMBER: _ClassVar[int]
optional: _descriptor.FieldDescriptor

class Dimension(_message.Message):
__slots__ = ("name", "synonyms", "description", "expr", "data_type", "unique", "sample_values")
__slots__ = ("name", "synonyms", "description", "expr", "data_type", "unique", "sample_values", "cortex_search_service_name")
NAME_FIELD_NUMBER: _ClassVar[int]
SYNONYMS_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
EXPR_FIELD_NUMBER: _ClassVar[int]
DATA_TYPE_FIELD_NUMBER: _ClassVar[int]
UNIQUE_FIELD_NUMBER: _ClassVar[int]
SAMPLE_VALUES_FIELD_NUMBER: _ClassVar[int]
CORTEX_SEARCH_SERVICE_NAME_FIELD_NUMBER: _ClassVar[int]
name: str
synonyms: _containers.RepeatedScalarFieldContainer[str]
description: str
expr: str
data_type: str
unique: bool
sample_values: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, name: _Optional[str] = ..., synonyms: _Optional[_Iterable[str]] = ..., description: _Optional[str] = ..., expr: _Optional[str] = ..., data_type: _Optional[str] = ..., unique: bool = ..., sample_values: _Optional[_Iterable[str]] = ...) -> None: ...
cortex_search_service_name: str
def __init__(self, name: _Optional[str] = ..., synonyms: _Optional[_Iterable[str]] = ..., description: _Optional[str] = ..., expr: _Optional[str] = ..., data_type: _Optional[str] = ..., unique: bool = ..., sample_values: _Optional[_Iterable[str]] = ..., cortex_search_service_name: _Optional[str] = ...) -> None: ...

class TimeDimension(_message.Message):
__slots__ = ("name", "synonyms", "description", "expr", "data_type", "unique", "sample_values")
Expand Down
30 changes: 27 additions & 3 deletions semantic_model_generator/tests/generate_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def mock_snowflake_connection():
comment=None,
)

_TABLE_WITH_THAT_EXCEEDS_CONTEXT = Table(
_TABLE_WITH_MANY_SAMPLE_VALUES = Table(
id_=0,
name="PRODUCTS",
columns=[
Expand All @@ -153,6 +153,22 @@ def mock_snowflake_connection():
comment=None,
)

# Fixture: a table with 200 generated NUMBER columns. Its semantic-model
# representation is large on purpose; tests that exercise the
# "model too large" path use it as the mocked table representation.
# Each column carries only three short sample values, so the size comes
# from the sheer number of columns rather than from sample values.
_TABLE_THAT_EXCEEDS_CONTEXT = Table(
    id_=0,
    name="PRODUCTS",
    columns=[
        Column(
            id_=i,
            column_name=f"column_{i}",
            column_type="NUMBER",
            values=["1", "2", "3"],
            comment=None,
        )
        for i in range(200)
    ],
    comment=None,
)


@pytest.fixture
def mock_snowflake_connection_env(monkeypatch):
Expand Down Expand Up @@ -298,7 +314,7 @@ def mock_dependencies_exceed_context(mock_snowflake_connection):
valid_schemas_tables_columns_df_zip_code,
]
table_representations = [
_TABLE_WITH_THAT_EXCEEDS_CONTEXT, # Value returned on the first call.
_TABLE_THAT_EXCEEDS_CONTEXT, # Value returned on the first call.
]

with patch(
Expand Down Expand Up @@ -494,7 +510,15 @@ def test_generate_base_context_from_table_that_has_too_long_context(

mock_file.assert_called_once_with(output_path, "w")
mock_logger.warning.assert_called_once_with(
"WARNING 🚨: The Semantic model is too large. \n Passed size is 26867 characters. We need you to remove 784 characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values. \n Once you've finished updating, please validate your semantic model."
"WARNING 🚨: "
"The Semantic model is too large. \n"
"Passed size is 41701 characters. "
"We need you to remove 16180 characters in your semantic model. "
"Please check: \n "
"(1) If you have long descriptions that can be truncated. \n "
"(2) If you can remove some columns that are not used within your tables. \n "
"(3) If you have extra tables you do not need. \n "
"Once you've finished updating, please validate your semantic model."
)

mock_file.assert_called_once_with(output_path, "w")
Expand Down
32 changes: 31 additions & 1 deletion semantic_model_generator/tests/samples/validate_yamls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from semantic_model_generator.protos import semantic_model_pb2

_VALID_YAML = """name: my test semantic model
tables:
- name: ALIAS
Expand Down Expand Up @@ -279,7 +281,7 @@
"""


_VALID_YAML_TOO_LONG_CONTEXT = """name: my test semantic model
_INVALID_YAML_TOO_LONG_CONTEXT = """name: my test semantic model
tables:
- name: ALIAS
base_table:
Expand Down Expand Up @@ -339,3 +341,31 @@
data_type: TEXT
sample_values: ['Holtsville', 'Adjuntas', 'Boqueron']
"""

# Fixture: a semantic model with 5 dimensions, each carrying 500 sample
# values (5 fruit names repeated 100 times). Built directly as a proto —
# rather than a YAML string literal like the other fixtures in this file —
# so the long sample-value lists can be generated programmatically.
_VALID_YAML_MANY_SAMPLE_VALUES = semantic_model_pb2.SemanticModel(
    name="test model",
    tables=[
        semantic_model_pb2.Table(
            name="ALIAS",
            base_table=semantic_model_pb2.FullyQualifiedTable(
                database="AUTOSQL_DATASET_BIRD_V2", schema="ADDRESS", table="ALIAS"
            ),
            dimensions=[
                semantic_model_pb2.Dimension(
                    name=f"DIMENSION_{i}",
                    expr="ALIAS",
                    data_type="TEXT",
                    # 5 distinct values * 100 -> 500 sample values per dimension.
                    sample_values=[
                        "apple",
                        "banana",
                        "cantaloupe",
                        "date",
                        "elderberry",
                    ]
                    * 100,
                )
                for i in range(5)
            ],
        )
    ],
)
32 changes: 25 additions & 7 deletions semantic_model_generator/tests/validate_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from strictyaml import DuplicateKeysDisallowed, YAMLValidationError

from semantic_model_generator.data_processing.proto_utils import proto_to_yaml
from semantic_model_generator.tests.samples import validate_yamls
from semantic_model_generator.validate_model import validate_from_local_path

Expand Down Expand Up @@ -81,10 +82,10 @@ def temp_invalid_yaml_incorrect_dtype():


@pytest.fixture
def temp_valid_yaml_too_long_context():
def temp_invalid_yaml_too_long_context():
"""Create a temporary YAML file with the test data."""
with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp:
tmp.write(validate_yamls._VALID_YAML_TOO_LONG_CONTEXT)
tmp.write(validate_yamls._INVALID_YAML_TOO_LONG_CONTEXT)
tmp.flush()
yield tmp.name

Expand Down Expand Up @@ -219,13 +220,30 @@ def test_invalid_yaml_incorrect_datatype(


@mock.patch("semantic_model_generator.validate_model.logger")
def test_valid_yaml_too_long_context(
mock_logger, temp_valid_yaml_too_long_context, mock_snowflake_connection
def test_invalid_yaml_too_long_context(
mock_logger, temp_invalid_yaml_too_long_context, mock_snowflake_connection
):
account_name = "snowflake test"
with pytest.raises(ValueError) as exc_info:
validate_from_local_path(temp_valid_yaml_too_long_context, account_name)

expected_error = "Your semantic model is too large. Passed size is 41937 characters. We need you to remove 15856 characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values."
validate_from_local_path(temp_invalid_yaml_too_long_context, account_name)

expected_error = (
"Your semantic model is too large. "
"Passed size is 41937 characters. "
"We need you to remove 16416 characters in your semantic model. Please check: \n"
" (1) If you have long descriptions that can be truncated. \n"
" (2) If you can remove some columns that are not used within your tables. \n"
" (3) If you have extra tables you do not need."
)

assert expected_error in str(exc_info.value), "Unexpected error message"


@mock.patch("semantic_model_generator.validate_model.logger")
def test_valid_yaml_many_sample_values(mock_logger, mock_snowflake_connection):
    """A model with very many sample values should still validate cleanly."""
    account_name = "snowflake test"
    yaml_str = proto_to_yaml(validate_yamls._VALID_YAML_MANY_SAMPLE_VALUES)
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp:
        tmp.write(yaml_str)
        tmp.flush()
        # validate_from_local_path raises on failure and returns None on success.
        result = validate_from_local_path(tmp.name, account_name)
    assert result is None
92 changes: 84 additions & 8 deletions semantic_model_generator/validate/context_length.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,66 @@
from typing import TypeVar
from typing import Any, TypeVar

from google.protobuf.message import Message
from loguru import logger

from semantic_model_generator.data_processing.proto_utils import proto_to_yaml

ProtoMsg = TypeVar("ProtoMsg", bound=Message)
_MODEL_CONTEXT_LENGTH_TOKENS = 6500 # We use 6.5k, with 1.2k for instructions, so that we can reserve 500 for response tokens (average is 300).
_MODEL_CONTEXT_INSTR_TOKEN = 20 # buffer for instr tokens

# Max total tokens is 8200.
# We reserve 500 tokens for response (average response is 300 tokens).
# So the prompt token limit is 7700.
_TOTAL_PROMPT_TOKEN_LIMIT = 7700
_BASE_INSTRUCTION_TOKEN_LENGTH = 1220
_TOKENS_PER_LITERAL_RETRIEVAL = 100
sfc-gh-nlimtiaco marked this conversation as resolved.
Show resolved Hide resolved

sfc-gh-nlimtiaco marked this conversation as resolved.
Show resolved Hide resolved
# As per https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
_CHARS_PER_TOKEN = 4

# Max number of sample values we include in the semantic model representation.
_MAX_SAMPLE_VALUES = 3


def _get_field(msg: ProtoMsg, field_name: str) -> Any:
fields = [value for fd, value in msg.ListFields() if fd.name == field_name]
if not fields:
return None
return fields[0]


def _truncate_sample_values(model: ProtoMsg) -> None:
    """Truncate sample values of every column in *model*, in place.

    Keeps at most _MAX_SAMPLE_VALUES entries in the repeated `sample_values`
    field of each dimension and measure, so oversized models shrink before
    their context length is measured.
    """
    tables = _get_field(model, "tables")
    if not tables:
        return
    for table in tables:
        # Dimensions and measures expose an identical sample_values field;
        # iterate both kinds with one loop instead of two duplicated branches.
        # NOTE(review): time dimensions also appear to carry sample_values
        # (see TimeDimension in the .pyi) — confirm whether they should be
        # truncated here as well.
        for field_name in ("dimensions", "measures"):
            columns = _get_field(table, field_name)
            if not columns:
                continue
            for column in columns:
                sample_values = _get_field(column, "sample_values")
                if sample_values:
                    # Deleting the tail slice mutates the repeated proto
                    # field in place.
                    del sample_values[_MAX_SAMPLE_VALUES:]


def _count_search_services(model: ProtoMsg) -> int:
    """Return the number of dimensions that name a Cortex Search Service."""
    tables = _get_field(model, "tables")
    if not tables:
        return 0
    total = 0
    for table in tables:
        # A table without dimensions contributes nothing; `or []` keeps the
        # inner loop guard-free.
        for dimension in _get_field(table, "dimensions") or []:
            if _get_field(dimension, "cortex_search_service_name"):
                total += 1
    return total


def validate_context_length(model: ProtoMsg, throw_error: bool = False) -> None:
Expand All @@ -19,17 +72,40 @@ def validate_context_length(model: ProtoMsg, throw_error: bool = False) -> None:
"""

model.ClearField("verified_queries")
_truncate_sample_values(model)
num_search_services = _count_search_services(model)

yaml_str = proto_to_yaml(model)
# Pass in the str version of the semantic context yaml.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this'll happen often: but i don't recall we strip any whitespaces during loading yaml/convert to proto; Shall we add stripping for whitespaces?
We can potentially add to proto field options like what Daniel had here: https://github.com/snowflakedb/cortex/pull/114221/files

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if whitespace is a problem, have you seen any issues with it? For yaml -> proto, I don't know of strictyaml not being robust to extra white space. For proto -> yaml, I don't see where we would output excess whitespace.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, probably strictyaml handles that. Can you help add a todo at the top of this file just reminder to keep an eye on the white spaces?

# This isn't exactly how many tokens the model will be, but should roughly be correct.
TOTAL_TOKENS_LIMIT = _MODEL_CONTEXT_LENGTH_TOKENS + _MODEL_CONTEXT_INSTR_TOKEN
CHARS_PER_TOKEN = 4 # as per https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
if len(yaml_str) // CHARS_PER_TOKEN > TOTAL_TOKENS_LIMIT:
literals_buffer = (
_TOKENS_PER_LITERAL_RETRIEVAL
+ num_search_services * _TOKENS_PER_LITERAL_RETRIEVAL
)
approx_instruction_length = _BASE_INSTRUCTION_TOKEN_LENGTH + literals_buffer
model_tokens_limit = _TOTAL_PROMPT_TOKEN_LIMIT - approx_instruction_length
model_tokens = len(yaml_str) // _CHARS_PER_TOKEN
if model_tokens > model_tokens_limit:
tokens_to_remove = model_tokens - model_tokens_limit
chars_to_remove = tokens_to_remove * _CHARS_PER_TOKEN
if throw_error:
raise ValueError(
f"Your semantic model is too large. Passed size is {len(yaml_str)} characters. We need you to remove {((len(yaml_str) // CHARS_PER_TOKEN)-TOTAL_TOKENS_LIMIT ) *CHARS_PER_TOKEN } characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values."
f"Your semantic model is too large. "
f"Passed size is {len(yaml_str)} characters. "
f"We need you to remove {chars_to_remove} characters in your semantic model. "
f"Please check: \n"
f" (1) If you have long descriptions that can be truncated. \n"
f" (2) If you can remove some columns that are not used within your tables. \n"
f" (3) If you have extra tables you do not need."
)
else:
logger.warning(
f"WARNING 🚨: The Semantic model is too large. \n Passed size is {len(yaml_str)} characters. We need you to remove {((len(yaml_str) // CHARS_PER_TOKEN)-TOTAL_TOKENS_LIMIT ) *CHARS_PER_TOKEN } characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values. \n Once you've finished updating, please validate your semantic model."
f"WARNING 🚨: The Semantic model is too large. \n"
f"Passed size is {len(yaml_str)} characters. "
f"We need you to remove {chars_to_remove} characters in your semantic model. "
f"Please check: \n"
f" (1) If you have long descriptions that can be truncated. \n"
f" (2) If you can remove some columns that are not used within your tables. \n"
f" (3) If you have extra tables you do not need. \n"
f" Once you've finished updating, please validate your semantic model."
)
Loading