Skip to content

Commit

Permalink
Convert to using mashumaro jsonschema with acceptable performance (#8437
Browse files Browse the repository at this point in the history
)
  • Loading branch information
gshank authored and peterallenwebb committed Aug 30, 2023
1 parent 22216a3 commit 9097548
Show file tree
Hide file tree
Showing 64 changed files with 4,115 additions and 3,683 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230718-145428.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Switch from hologram to mashumaro jsonschema
time: 2023-07-18T14:54:28.41453-04:00
custom:
Author: gshank
Issue: "8426"
14 changes: 13 additions & 1 deletion core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,21 @@ def initial_result(self, resource_type: NodeType, base: bool) -> C:

def _update_from_config(self, result: C, partial: Dict[str, Any], validate: bool = False) -> C:
translated = self._active_project.credentials.translate_aliases(partial)
return result.update_from(
translated = self.translate_hook_names(translated)
updated = result.update_from(
translated, self._active_project.credentials.type, validate=validate
)
return updated

def translate_hook_names(self, project_dict):
# This is a kind of kludge because the fix for #6411 specifically allowed misspelling
# the hook field names in dbt_project.yml, which only ever worked because we didn't
# run validate on the dbt_project configs.
if "pre_hook" in project_dict:
project_dict["pre-hook"] = project_dict.pop("pre_hook")
if "post_hook" in project_dict:
project_dict["post-hook"] = project_dict.pop("post_hook")
return project_dict

def calculate_node_config_dict(
self,
Expand Down
17 changes: 7 additions & 10 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,21 @@
from dbt.events.functions import fire_event
from dbt.events.types import NewConnectionOpening
from dbt.events.contextvars import get_node_info
from typing_extensions import Protocol
from typing_extensions import Protocol, Annotated
from dbt.dataclass_schema import (
dbtClassMixin,
StrEnum,
ExtensibleDbtClassMixin,
HyphenatedDbtClassMixin,
ValidatedStringMixin,
register_pattern,
)
from dbt.contracts.util import Replaceable
from mashumaro.jsonschema.annotations import Pattern


class Identifier(ValidatedStringMixin):
ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"


# we need register_pattern for jsonschema validation
register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")


@dataclass
class AdapterResponse(dbtClassMixin):
_message: str
Expand All @@ -55,7 +50,8 @@ class ConnectionState(StrEnum):

@dataclass(init=False)
class Connection(ExtensibleDbtClassMixin, Replaceable):
type: Identifier
# Annotated is used by mashumaro for jsonschema generation
type: Annotated[Identifier, Pattern(r"^[A-Za-z_][A-Za-z0-9_]+$")]
name: Optional[str] = None
state: ConnectionState = ConnectionState.INIT
transaction_open: bool = False
Expand Down Expand Up @@ -161,6 +157,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
# Need to fixup dbname => database, pass => password
data = cls.translate_aliases(data)
return data

Expand Down Expand Up @@ -220,10 +217,10 @@ def to_target_dict(self):


@dataclass
class QueryComment(HyphenatedDbtClassMixin):
class QueryComment(dbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
job_label: bool = False
job_label: bool = field(default=False, metadata={"alias": "job-label"})


class AdapterRequiredConfig(HasCredentials, Protocol):
Expand Down
44 changes: 9 additions & 35 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from enum import Enum
from itertools import chain
from typing import Any, List, Optional, Dict, Union, Type, TypeVar, Callable
from typing_extensions import Annotated

from dbt.dataclass_schema import (
dbtClassMixin,
ValidationError,
register_pattern,
StrEnum,
)
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed, Docs
Expand All @@ -15,6 +15,7 @@
from dbt.exceptions import DbtInternalError, CompilationError
from dbt import hooks
from dbt.node_types import NodeType
from mashumaro.jsonschema.annotations import Pattern


M = TypeVar("M", bound="Metadata")
Expand Down Expand Up @@ -188,9 +189,6 @@ class Severity(str):
pass


register_pattern(Severity, insensitive_patterns("warn", "error"))


class OnConfigurationChangeOption(StrEnum):
Apply = "apply"
Continue = "continue"
Expand Down Expand Up @@ -376,15 +374,6 @@ def finalize_and_validate(self: T) -> T:
self.validate(dct)
return self.from_dict(dct)

def replace(self, **kwargs):
dct = self.to_dict(omit_none=True)

mapping = self.field_mapping()
for key, value in kwargs.items():
new_key = mapping.get(key, key)
dct[new_key] = value
return self.from_dict(dct)


@dataclass
class SemanticModelConfig(BaseConfig):
Expand Down Expand Up @@ -447,11 +436,11 @@ class NodeConfig(NodeAndTestConfig):
persist_docs: Dict[str, Any] = field(default_factory=dict)
post_hook: List[Hook] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
metadata={"merge": MergeBehavior.Append, "alias": "post-hook"},
)
pre_hook: List[Hook] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
metadata={"merge": MergeBehavior.Append, "alias": "pre-hook"},
)
quoting: Dict[str, Any] = field(
default_factory=dict,
Expand Down Expand Up @@ -511,30 +500,11 @@ def __post_init__(self):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
field_map = {"post-hook": "post_hook", "pre-hook": "pre_hook"}
# create a new dict because otherwise it gets overwritten in
# tests
new_dict = {}
for key in data:
new_dict[key] = data[key]
data = new_dict
for key in hooks.ModelHookType:
if key in data:
data[key] = [hooks.get_hook_dict(h) for h in data[key]]
for field_name in field_map:
if field_name in data:
new_name = field_map[field_name]
data[new_name] = data.pop(field_name)
return data

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
field_map = {"post_hook": "post-hook", "pre_hook": "pre-hook"}
for field_name in field_map:
if field_name in dct:
dct[field_map[field_name]] = dct.pop(field_name)
return dct

# this is still used by jsonschema validation
@classmethod
def field_mapping(cls):
Expand All @@ -554,6 +524,9 @@ def validate(cls, data):
raise ValidationError("A seed must have a materialized value of 'seed'")


SEVERITY_PATTERN = r"^([Ww][Aa][Rr][Nn]|[Ee][Rr][Rr][Oo][Rr])$"


@dataclass
class TestConfig(NodeAndTestConfig):
__test__ = False
Expand All @@ -564,7 +537,8 @@ class TestConfig(NodeAndTestConfig):
metadata=CompareBehavior.Exclude.meta(),
)
materialized: str = "test"
severity: Severity = Severity("ERROR")
# Annotated is used by mashumaro for jsonschema generation
severity: Annotated[Severity, Pattern(SEVERITY_PATTERN)] = Severity("ERROR")
store_failures: Optional[bool] = None
where: Optional[str] = None
limit: Optional[int] = None
Expand Down
36 changes: 18 additions & 18 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import hashlib

from mashumaro.types import SerializableType
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator, Literal

from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

Expand Down Expand Up @@ -556,18 +556,18 @@ def depends_on_macros(self):

@dataclass
class AnalysisNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
resource_type: Literal[NodeType.Analysis]


@dataclass
class HookNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
resource_type: Literal[NodeType.Operation]
index: Optional[int] = None


@dataclass
class ModelNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
resource_type: Literal[NodeType.Model]
access: AccessType = AccessType.Protected
constraints: List[ModelLevelConstraint] = field(default_factory=list)
version: Optional[NodeVersion] = None
Expand Down Expand Up @@ -854,12 +854,12 @@ def same_contract(self, old, adapter_type=None) -> bool:
# TODO: rm?
@dataclass
class RPCNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
resource_type: Literal[NodeType.RPCCall]


@dataclass
class SqlNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]})
resource_type: Literal[NodeType.SqlOperation]


# ====================================
Expand All @@ -869,7 +869,7 @@ class SqlNode(CompiledNode):

@dataclass
class SeedNode(ParsedNode): # No SQLDefaults!
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
resource_type: Literal[NodeType.Seed]
config: SeedConfig = field(default_factory=SeedConfig)
# seeds need the root_path because the contents are not loaded initially
# and we need the root_path to load the seed later
Expand Down Expand Up @@ -995,7 +995,7 @@ def is_relational(self):

@dataclass
class SingularTestNode(TestShouldStoreFailures, CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
resource_type: Literal[NodeType.Test]
# Was not able to make mypy happy and keep the code working. We need to
# refactor the various configs.
config: TestConfig = field(default_factory=TestConfig) # type: ignore
Expand Down Expand Up @@ -1031,7 +1031,7 @@ class HasTestMetadata(dbtClassMixin):

@dataclass
class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
resource_type: Literal[NodeType.Test]
column_name: Optional[str] = None
file_key_name: Optional[str] = None
# Was not able to make mypy happy and keep the code working. We need to
Expand Down Expand Up @@ -1064,13 +1064,13 @@ class IntermediateSnapshotNode(CompiledNode):
# uses a regular node config, which the snapshot parser will then convert
# into a full ParsedSnapshotNode after rendering. Note: it currently does
# not work to set snapshot config in schema files because of the validation.
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
resource_type: Literal[NodeType.Snapshot]
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)


@dataclass
class SnapshotNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
resource_type: Literal[NodeType.Snapshot]
config: SnapshotConfig
defer_relation: Optional[DeferRelation] = None

Expand All @@ -1083,7 +1083,7 @@ class SnapshotNode(CompiledNode):
@dataclass
class Macro(BaseNode):
macro_sql: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
resource_type: Literal[NodeType.Macro]
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -1113,7 +1113,7 @@ def depends_on_macros(self):
@dataclass
class Documentation(BaseNode):
block_contents: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Documentation]})
resource_type: Literal[NodeType.Documentation]

@property
def search_name(self):
Expand Down Expand Up @@ -1144,7 +1144,7 @@ class UnpatchedSourceDefinition(BaseNode):
source: UnparsedSourceDefinition
table: UnparsedSourceTableDefinition
fqn: List[str]
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
resource_type: Literal[NodeType.Source]
patch_path: Optional[str] = None

def get_full_source_name(self):
Expand Down Expand Up @@ -1189,7 +1189,7 @@ class ParsedSourceMandatory(GraphNode, HasRelationMetadata):
source_description: str
loader: str
identifier: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
resource_type: Literal[NodeType.Source]


@dataclass
Expand Down Expand Up @@ -1316,7 +1316,7 @@ def search_name(self):
class Exposure(GraphNode):
type: ExposureType
owner: Owner
resource_type: NodeType = field(metadata={"restrict": [NodeType.Exposure]})
resource_type: Literal[NodeType.Exposure]
description: str = ""
label: Optional[str] = None
maturity: Optional[MaturityType] = None
Expand Down Expand Up @@ -1465,7 +1465,7 @@ class Metric(GraphNode):
type_params: MetricTypeParams
filter: Optional[WhereFilter] = None
metadata: Optional[SourceFileMetadata] = None
resource_type: NodeType = field(metadata={"restrict": [NodeType.Metric]})
resource_type: Literal[NodeType.Metric]
meta: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
config: MetricConfig = field(default_factory=MetricConfig)
Expand Down Expand Up @@ -1548,7 +1548,7 @@ def same_contents(self, old: Optional["Metric"]) -> bool:
class Group(BaseNode):
name: str
owner: Owner
resource_type: NodeType = field(metadata={"restrict": [NodeType.Group]})
resource_type: Literal[NodeType.Group]


# ====================================
Expand Down
Loading

0 comments on commit 9097548

Please sign in to comment.