Skip to content

Commit

Permalink
fix tests and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nlimtiaco committed Jun 25, 2024
1 parent 7337168 commit c743e17
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 19 deletions.
12 changes: 6 additions & 6 deletions semantic_model_generator/data_processing/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class Column:
column_name: str
column_type: str
values: Optional[List[str]] = None
comment: Optional[str] = (
None # comment field's to save the column comment user specified on the column
)
comment: Optional[
str
] = None # comment field's to save the column comment user specified on the column

def __post_init__(self: Any) -> None:
"""
Expand All @@ -37,9 +37,9 @@ class Table:
id_: int
name: str
columns: List[Column]
comment: Optional[str] = (
None # comment field's to save the table comment user specified on the table
)
comment: Optional[
str
] = None # comment field's to save the table comment user specified on the table

def __post_init__(self: Any) -> None:
for col in self.columns:
Expand Down
30 changes: 27 additions & 3 deletions semantic_model_generator/tests/generate_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def mock_snowflake_connection():
comment=None,
)

_TABLE_WITH_THAT_EXCEEDS_CONTEXT = Table(
_TABLE_WITH_MANY_SAMPLE_VALUES = Table(
id_=0,
name="PRODUCTS",
columns=[
Expand All @@ -153,6 +153,22 @@ def mock_snowflake_connection():
comment=None,
)

# Fixture table with 200 identically-shaped NUMBER columns — large enough that
# its serialized semantic model exceeds the context limit under test.
_TABLE_THAT_EXCEEDS_CONTEXT = Table(
    id_=0,
    name="PRODUCTS",
    columns=list(
        Column(
            id_=col_idx,
            column_name=f"column_{col_idx}",
            column_type="NUMBER",
            values=["1", "2", "3"],
            comment=None,
        )
        for col_idx in range(200)
    ),
    comment=None,
)


@pytest.fixture
def mock_snowflake_connection_env(monkeypatch):
Expand Down Expand Up @@ -298,7 +314,7 @@ def mock_dependencies_exceed_context(mock_snowflake_connection):
valid_schemas_tables_columns_df_zip_code,
]
table_representations = [
_TABLE_WITH_THAT_EXCEEDS_CONTEXT, # Value returned on the first call.
_TABLE_THAT_EXCEEDS_CONTEXT, # Value returned on the first call.
]

with patch(
Expand Down Expand Up @@ -494,7 +510,15 @@ def test_generate_base_context_from_table_that_has_too_long_context(

mock_file.assert_called_once_with(output_path, "w")
mock_logger.warning.assert_called_once_with(
"WARNING 🚨: The Semantic model is too large. \n Passed size is 26867 characters. We need you to remove 784 characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values. \n Once you've finished updating, please validate your semantic model."
"WARNING 🚨: "
"The Semantic model is too large. \n"
"Passed size is 41701 characters. "
"We need you to remove 16180 characters in your semantic model. "
"Please check: \n "
"(1) If you have long descriptions that can be truncated. \n "
"(2) If you can remove some columns that are not used within your tables. \n "
"(3) If you have extra tables you do not need. \n "
"Once you've finished updating, please validate your semantic model."
)

mock_file.assert_called_once_with(output_path, "w")
Expand Down
32 changes: 31 additions & 1 deletion semantic_model_generator/tests/samples/validate_yamls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from semantic_model_generator.protos import semantic_model_pb2

_VALID_YAML = """name: my test semantic model
tables:
- name: ALIAS
Expand Down Expand Up @@ -279,7 +281,7 @@
"""


_VALID_YAML_TOO_LONG_CONTEXT = """name: my test semantic model
_INVALID_YAML_TOO_LONG_CONTEXT = """name: my test semantic model
tables:
- name: ALIAS
base_table:
Expand Down Expand Up @@ -339,3 +341,31 @@
data_type: TEXT
sample_values: ['Holtsville', 'Adjuntas', 'Boqueron']
"""

# Semantic model proto whose five dimensions each carry 500 sample values
# (the five fruit strings repeated 100x), used to exercise validation of
# models with many sample values.
_VALID_YAML_MANY_SAMPLE_VALUES = semantic_model_pb2.SemanticModel(
    name="test model",
    tables=[
        semantic_model_pb2.Table(
            name="ALIAS",
            base_table=semantic_model_pb2.FullyQualifiedTable(
                database="AUTOSQL_DATASET_BIRD_V2",
                schema="ADDRESS",
                table="ALIAS",
            ),
            dimensions=[
                semantic_model_pb2.Dimension(
                    name=f"DIMENSION_{dim_idx}",
                    expr="ALIAS",
                    data_type="TEXT",
                    sample_values=(
                        ["apple", "banana", "cantaloupe", "date", "elderberry"]
                        * 100
                    ),
                )
                for dim_idx in range(5)
            ],
        )
    ],
)
32 changes: 25 additions & 7 deletions semantic_model_generator/tests/validate_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from strictyaml import DuplicateKeysDisallowed, YAMLValidationError

from semantic_model_generator.data_processing.proto_utils import proto_to_yaml
from semantic_model_generator.tests.samples import validate_yamls
from semantic_model_generator.validate_model import validate_from_local_path

Expand Down Expand Up @@ -81,10 +82,10 @@ def temp_invalid_yaml_incorrect_dtype():


@pytest.fixture
def temp_valid_yaml_too_long_context():
def temp_invalid_yaml_too_long_context():
"""Create a temporary YAML file with the test data."""
with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp:
tmp.write(validate_yamls._VALID_YAML_TOO_LONG_CONTEXT)
tmp.write(validate_yamls._INVALID_YAML_TOO_LONG_CONTEXT)
tmp.flush()
yield tmp.name

Expand Down Expand Up @@ -219,13 +220,30 @@ def test_invalid_yaml_incorrect_datatype(


@mock.patch("semantic_model_generator.validate_model.logger")
def test_valid_yaml_too_long_context(
mock_logger, temp_valid_yaml_too_long_context, mock_snowflake_connection
def test_invalid_yaml_too_long_context(
mock_logger, temp_invalid_yaml_too_long_context, mock_snowflake_connection
):
account_name = "snowflake test"
with pytest.raises(ValueError) as exc_info:
validate_from_local_path(temp_valid_yaml_too_long_context, account_name)

expected_error = "Your semantic model is too large. Passed size is 41937 characters. We need you to remove 15856 characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values."
validate_from_local_path(temp_invalid_yaml_too_long_context, account_name)

expected_error = (
"Your semantic model is too large. "
"Passed size is 41937 characters. "
"We need you to remove 16416 characters in your semantic model. Please check: \n"
" (1) If you have long descriptions that can be truncated. \n"
" (2) If you can remove some columns that are not used within your tables. \n"
" (3) If you have extra tables you do not need."
)

assert expected_error in str(exc_info.value), "Unexpected error message"


@mock.patch("semantic_model_generator.validate_model.logger")
def test_valid_yaml_many_sample_values(mock_logger, mock_snowflake_connection):
    """A model with many sample values per dimension still validates cleanly."""
    account = "snowflake test"
    serialized = proto_to_yaml(validate_yamls._VALID_YAML_MANY_SAMPLE_VALUES)
    # Write the YAML to a real temp file because validation reads from a path.
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as yaml_file:
        yaml_file.write(serialized)
        yaml_file.flush()
        # validate_from_local_path raises on invalid models; success returns None.
        assert validate_from_local_path(yaml_file.name, account) is None
7 changes: 5 additions & 2 deletions semantic_model_generator/validate/context_length.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, Any
from typing import Any, TypeVar

from google.protobuf.message import Message
from loguru import logger
Expand Down Expand Up @@ -78,7 +78,10 @@ def validate_context_length(model: ProtoMsg, throw_error: bool = False) -> None:
yaml_str = proto_to_yaml(model)
# Pass in the str version of the semantic context yaml.
# This isn't exactly how many tokens the model will be, but should roughly be correct.
literals_buffer = _TOKENS_PER_LITERAL_RETRIEVAL + num_search_services * _TOKENS_PER_LITERAL_RETRIEVAL
literals_buffer = (
_TOKENS_PER_LITERAL_RETRIEVAL
+ num_search_services * _TOKENS_PER_LITERAL_RETRIEVAL
)
approx_instruction_length = _BASE_INSTRUCTION_TOKEN_LENGTH + literals_buffer
model_tokens_limit = _TOTAL_PROMPT_TOKEN_LIMIT - approx_instruction_length
model_tokens = len(yaml_str) // _CHARS_PER_TOKEN
Expand Down

0 comments on commit c743e17

Please sign in to comment.