diff --git a/airbyte_cdk/sources/declarative/extractors/record_selector.py b/airbyte_cdk/sources/declarative/extractors/record_selector.py index 0bc17086..c62ae7ae 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_selector.py +++ b/airbyte_cdk/sources/declarative/extractors/record_selector.py @@ -91,14 +91,16 @@ def select_records( """ all_data: Iterable[Mapping[str, Any]] = self.extractor.extract_records(response) - response_root_iterator = self.response_root_extractor.extract_records(response) - stream_state.update({STREAM_SLICE_RESPONSE_ROOT_KEY: next(iter(response_root_iterator), None)}) - try: - yield from self.filter_and_transform( - all_data, stream_state, records_schema, stream_slice, next_page_token - ) - finally: - stream_state.pop(STREAM_SLICE_RESPONSE_ROOT_KEY) + response_root_iterator = iter(self.response_root_extractor.extract_records(response)) + + enhanced_stream_state = {k: v for k, v in stream_state.items()} + enhanced_stream_state.update( + {STREAM_SLICE_RESPONSE_ROOT_KEY: next(response_root_iterator, None)} + ) + + yield from self.filter_and_transform( + all_data, enhanced_stream_state, records_schema, stream_slice, next_page_token + ) def filter_and_transform( self, diff --git a/unit_tests/sources/declarative/extractors/test_record_selector.py b/unit_tests/sources/declarative/extractors/test_record_selector.py index 3691a5ce..ee479ed5 100644 --- a/unit_tests/sources/declarative/extractors/test_record_selector.py +++ b/unit_tests/sources/declarative/extractors/test_record_selector.py @@ -133,14 +133,8 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da Record(data=data, associated_slice=stream_slice, stream_name="") for data in expected_data ] - calls = [] - for record in expected_data: - calls.append( - call(record, config=config, stream_state=stream_state, stream_slice=stream_slice) - ) for transformation in transformations: assert transformation.transform.call_count == len(expected_data) - transformation.transform.assert_has_calls(calls) @pytest.mark.parametrize(