Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support "just in time" loading of records, and add ID fields #154

Merged
merged 10 commits into from
Jun 26, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240618-155025.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Deserialize Record objects on a just-in-time basis.
time: 2024-06-18T15:50:25.985387-04:00
custom:
Author: peterallenwebb
Issue: "151"
4 changes: 3 additions & 1 deletion dbt_common/clients/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _include(self) -> bool:
# Do not record or replay filesystem searches that were performed against
# files which are actually part of dbt's implementation.
return (
"dbt/include/global_project" not in self.root_path
"dbt/include"
not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse?
and "dbt/include/global_project" not in self.root_path
and "/plugins/postgres/dbt/include/" not in self.root_path
)

Expand Down
71 changes: 45 additions & 26 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
external systems during a command invocation, so that the command can be re-run
later with the recording 'replayed' to dbt.

The rationale for and architecture of this module is described in detail in the
The rationale for and architecture of this module are described in detail in the
docs/guides/record_replay.md document in this repository.
"""
import functools
Expand All @@ -12,7 +12,7 @@

from deepdiff import DeepDiff # type: ignore
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional, Type
from typing import Any, Callable, Dict, List, Mapping, Optional, Type

from dbt_common.context import get_invocation_context

Expand Down Expand Up @@ -129,10 +129,11 @@ def __init__(
previous_recording_path: Optional[str] = None,
) -> None:
self.mode = mode
self.types = types
self.recorded_types = types
self._records_by_type: Dict[str, List[Record]] = {}
self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {}
self._replay_diffs: List["Diff"] = []
self.diff: Diff
self.diff: Optional[Diff] = None
self.previous_recording_path = previous_recording_path
self.current_recording_path = current_recording_path

Expand All @@ -146,7 +147,7 @@ def __init__(
)

if self.mode == RecorderMode.REPLAY:
self._records_by_type = self.load(self.previous_recording_path)
self._unprocessed_records_by_type = self.load(self.previous_recording_path)

@classmethod
def register_record_type(cls, rec_type) -> Any:
Expand All @@ -161,7 +162,14 @@ def add_record(self, record: Record) -> None:
self._records_by_type[rec_cls_name].append(record)

def pop_matching_record(self, params: Any) -> Optional[Record]:
rec_type_name = self._record_name_by_params_name[type(params).__name__]
rec_type_name = self._record_name_by_params_name.get(type(params).__name__)

if rec_type_name is None:
raise Exception(
f"A record of type {type(params).__name__} was requested, but no such type has been registered."
)

self._ensure_records_processed(rec_type_name)
records = self._records_by_type[rec_type_name]
match: Optional[Record] = None
for rec in records:
Expand All @@ -186,39 +194,40 @@ def _to_dict(self) -> Dict:
return dct

@classmethod
def load(cls, file_name: str) -> Dict[str, List[Record]]:
def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]:
with open(file_name) as file:
loaded_dct = json.load(file)
return json.load(file)

records_by_type: Dict[str, List[Record]] = {}
def _ensure_records_processed(self, record_type_name: str) -> None:
if record_type_name in self._records_by_type:
return

for record_type_name in loaded_dct:
# TODO: this breaks with QueryRecord on replay since it's
# not in common so isn't part of cls._record_cls_by_name yet
record_cls = cls._record_cls_by_name[record_type_name]
rec_list = []
for record_dct in loaded_dct[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
records_by_type[record_type_name] = rec_list
return records_by_type
rec_list = []
record_cls = self._record_cls_by_name[record_type_name]
for record_dct in self._unprocessed_records_by_type[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
self._records_by_type[record_type_name] = rec_list

def expect_record(self, params: Any) -> Any:
record = self.pop_matching_record(params)

if record is None:
raise Exception()

if record.result is None:
return None

result_tuple = dataclasses.astuple(record.result)
return result_tuple[0] if len(result_tuple) == 1 else result_tuple

def write_diffs(self, diff_file_name) -> None:
json.dump(
self.diff.calculate_diff(),
open(diff_file_name, "w"),
)
assert self.diff is not None
with open(diff_file_name, "w") as f:
json.dump(self.diff.calculate_diff(), f)

def print_diffs(self) -> None:
assert self.diff is not None
print(repr(self.diff.calculate_diff()))


Expand Down Expand Up @@ -273,7 +282,12 @@ def get_record_types_from_dict(fp: str) -> List:
return list(loaded_dct.keys())


def record_function(record_type, method=False, tuple_result=False):
def record_function(
record_type,
method: bool = False,
tuple_result: bool = False,
id_field_name: Optional[str] = None,
) -> Callable:
def record_function_inner(func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
# record/replay decorator if a relevant env var is set.
Expand All @@ -291,12 +305,17 @@ def record_replay_wrapper(*args, **kwargs):
if recorder is None:
return func_to_record(*args, **kwargs)

if recorder.types is not None and record_type.__name__ not in recorder.types:
if (
recorder.recorded_types is not None
and record_type.__name__ not in recorder.recorded_types
):
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
# params constructor.
param_args = args[1:] if method else args
if method and id_field_name is not None:
param_args = (getattr(args[0], id_field_name),) + param_args

params = record_type.params_cls(*param_args, **kwargs)

Expand All @@ -313,7 +332,7 @@ def record_replay_wrapper(*args, **kwargs):
r = func_to_record(*args, **kwargs)
result = (
None
if r is None or record_type.result_cls is None
if record_type.result_cls is None
else record_type.result_cls(*r)
if tuple_result
else record_type.result_cls(r)
Expand Down
Loading