From fb99f92f97e789c7dac2e4f6560a272607113b3d Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sat, 9 Nov 2024 22:44:59 -0800 Subject: [PATCH] apply auto-fixes --- airbyte_cdk/config_observation.py | 13 +- airbyte_cdk/connector.py | 28 +- .../connector_builder_handler.py | 13 +- airbyte_cdk/connector_builder/main.py | 25 +- .../connector_builder/message_grouper.py | 128 ++-- airbyte_cdk/connector_builder/models.py | 49 +- airbyte_cdk/destinations/destination.py | 18 +- .../destinations/vector_db_based/config.py | 41 +- .../vector_db_based/document_processor.py | 48 +- .../destinations/vector_db_based/embedder.py | 75 ++- .../destinations/vector_db_based/indexer.py | 34 +- .../vector_db_based/test_utils.py | 10 +- .../destinations/vector_db_based/utils.py | 13 +- .../destinations/vector_db_based/writer.py | 19 +- airbyte_cdk/entrypoint.py | 34 +- airbyte_cdk/exception_handler.py | 11 +- airbyte_cdk/logger.py | 32 +- airbyte_cdk/models/airbyte_protocol.py | 46 +- .../models/airbyte_protocol_serializers.py | 14 +- .../models/file_transfer_record_message.py | 9 +- airbyte_cdk/models/well_known_types.py | 1 + airbyte_cdk/sources/abstract_source.py | 42 +- .../concurrent_read_processor.py | 43 +- .../concurrent_source/concurrent_source.py | 14 +- .../concurrent_source_adapter.py | 33 +- ...partition_generation_completed_sentinel.py | 11 +- .../stream_thread_exception.py | 5 +- .../concurrent_source/thread_pool_manager.py | 24 +- airbyte_cdk/sources/config.py | 8 +- .../sources/connector_state_manager.py | 71 +-- .../sources/declarative/async_job/job.py | 13 +- .../declarative/async_job/job_orchestrator.py | 98 ++-- .../declarative/async_job/job_tracker.py | 11 +- .../declarative/async_job/repository.py | 9 +- .../sources/declarative/async_job/status.py | 6 +- .../sources/declarative/async_job/timer.py | 9 +- .../auth/declarative_authenticator.py | 10 +- airbyte_cdk/sources/declarative/auth/jwt.py | 58 +- airbyte_cdk/sources/declarative/auth/oauth.py | 50 +- .../auth/selective_authenticator.py | 9 +- airbyte_cdk/sources/declarative/auth/token.py | 59 +- .../declarative/auth/token_provider.py | 20 +- .../declarative/checks/check_stream.py | 13 +- .../declarative/checks/connection_checker.py | 13 +- .../concurrency_level/concurrency_level.py | 16 +- .../concurrent_declarative_source.py | 37 +- .../declarative/datetime/datetime_parser.py | 14 +- .../declarative/datetime/min_max_datetime.py | 25 +- .../sources/declarative/declarative_source.py | 13 +- .../sources/declarative/declarative_stream.py | 71 +-- .../sources/declarative/decoders/decoder.py | 15 +- .../declarative/decoders/json_decoder.py | 24 +- .../declarative/decoders/noop_decoder.py | 6 +- .../decoders/pagination_decoder_decorator.py | 10 +- .../declarative/decoders/xml_decoder.py | 9 +- airbyte_cdk/sources/declarative/exceptions.py | 5 +- .../declarative/extractors/dpath_extractor.py | 10 +- .../declarative/extractors/http_selector.py | 15 +- .../extractors/record_extractor.py | 12 +- .../declarative/extractors/record_filter.py | 33 +- .../declarative/extractors/record_selector.py | 35 +- .../extractors/response_to_file_extractor.py | 49 +- .../incremental/datetime_based_cursor.py | 105 ++-- .../incremental/declarative_cursor.py | 4 +- .../incremental/global_substream_cursor.py | 84 ++- .../incremental/per_partition_cursor.py | 65 +-- .../incremental/per_partition_with_global.py | 46 +- .../resumable_full_refresh_cursor.py | 50 +- .../declarative/interpolation/filters.py | 26 +- .../interpolation/interpolated_boolean.py | 33 +- 
.../interpolation/interpolated_mapping.py | 18 +- .../interpolated_nested_mapping.py | 20 +- .../interpolation/interpolated_string.py | 24 +- .../interpolation/interpolation.py | 12 +- .../declarative/interpolation/jinja.py | 43 +- .../declarative/interpolation/macros.py | 36 +- .../manifest_declarative_source.py | 63 +- ...legacy_to_per_partition_state_migration.py | 7 +- .../declarative/migrations/state_migration.py | 10 +- .../models/declarative_component_schema.py | 549 +++++++++--------- .../declarative/parsers/custom_exceptions.py | 9 +- .../parsers/manifest_component_transformer.py | 11 +- .../parsers/manifest_reference_resolver.py | 36 +- .../parsers/model_to_component_factory.py | 167 +++--- .../cartesian_product_stream_slicer.py | 49 +- .../list_partition_router.py | 53 +- .../partition_routers/partition_router.py | 18 +- .../single_partition_router.py | 38 +- .../substream_partition_router.py | 67 ++- .../constant_backoff_strategy.py | 14 +- .../exponential_backoff_strategy.py | 14 +- .../backoff_strategies/header_helper.py | 16 +- .../wait_time_from_header_backoff_strategy.py | 18 +- ...until_time_from_header_backoff_strategy.py | 20 +- .../error_handlers/backoff_strategy.py | 4 +- .../error_handlers/composite_error_handler.py | 16 +- .../error_handlers/default_error_handler.py | 26 +- .../default_http_response_filter.py | 16 +- .../error_handlers/error_handler.py | 4 +- .../error_handlers/http_response_filter.py | 49 +- .../requesters/http_job_repository.py | 46 +- .../declarative/requesters/http_requester.py | 139 +++-- .../paginators/default_paginator.py | 91 ++- .../requesters/paginators/no_pagination.py | 41 +- .../requesters/paginators/paginator.py | 26 +- .../strategies/cursor_pagination_strategy.py | 30 +- .../paginators/strategies/offset_increment.py | 33 +- .../paginators/strategies/page_increment.py | 25 +- .../strategies/pagination_strategy.py | 33 +- .../paginators/strategies/stop_condition.py | 17 +- .../declarative/requesters/request_option.py | 13 +- ...datetime_based_request_options_provider.py | 43 +- .../default_request_options_provider.py | 33 +- ...erpolated_nested_request_input_provider.py | 25 +- .../interpolated_request_input_provider.py | 29 +- .../interpolated_request_options_provider.py | 68 +-- .../request_options_provider.py | 42 +- .../declarative/requesters/request_path.py | 8 +- .../declarative/requesters/requester.py | 93 ++- .../declarative/retrievers/async_retriever.py | 31 +- .../declarative/retrievers/retriever.py | 15 +- .../retrievers/simple_retriever.py | 172 +++--- .../schema/default_schema_loader.py | 11 +- .../schema/inline_schema_loader.py | 6 +- .../schema/json_file_schema_loader.py | 18 +- .../declarative/schema/schema_loader.py | 4 +- airbyte_cdk/sources/declarative/spec/spec.py | 16 +- .../stream_slicers/stream_slicer.py | 9 +- .../declarative/transformations/add_fields.py | 43 +- .../keys_to_lower_transformation.py | 11 +- .../transformations/remove_fields.py | 21 +- .../transformations/transformation.py | 18 +- airbyte_cdk/sources/declarative/types.py | 1 + .../declarative/yaml_declarative_source.py | 23 +- .../sources/embedded/base_integration.py | 15 +- airbyte_cdk/sources/embedded/catalog.py | 13 +- airbyte_cdk/sources/embedded/runner.py | 9 +- airbyte_cdk/sources/embedded/tools.py | 9 +- ...stract_file_based_availability_strategy.py | 24 +- ...efault_file_based_availability_strategy.py | 29 +- .../config/abstract_file_based_spec.py | 38 +- .../sources/file_based/config/avro_format.py | 3 +- 
.../sources/file_based/config/csv_format.py | 32 +- .../sources/file_based/config/excel_format.py | 4 +- .../config/file_based_stream_config.py | 36 +- .../sources/file_based/config/jsonl_format.py | 4 +- .../file_based/config/parquet_format.py | 3 +- .../file_based/config/unstructured_format.py | 13 +- .../abstract_discovery_policy.py | 4 +- .../default_discovery_policy.py | 5 +- airbyte_cdk/sources/file_based/exceptions.py | 12 +- .../sources/file_based/file_based_source.py | 55 +- .../file_based/file_based_stream_reader.py | 48 +- .../file_based/file_types/avro_parser.py | 53 +- .../file_based/file_types/csv_parser.py | 86 ++- .../file_based/file_types/excel_parser.py | 50 +- .../file_based/file_types/file_transfer.py | 8 +- .../file_based/file_types/file_type_parser.py | 48 +- .../file_based/file_types/jsonl_parser.py | 31 +- .../file_based/file_types/parquet_parser.py | 96 ++- .../file_types/unstructured_parser.py | 94 ++- airbyte_cdk/sources/file_based/remote_file.py | 8 +- .../sources/file_based/schema_helpers.py | 74 ++- .../abstract_schema_validation_policy.py | 12 +- .../default_schema_validation_policies.py | 10 +- .../stream/abstract_file_based_stream.py | 80 ++- .../file_based/stream/concurrent/adapters.py | 63 +- .../abstract_concurrent_file_based_cursor.py | 7 +- .../cursor/file_based_concurrent_cursor.py | 75 +-- .../cursor/file_based_final_state_cursor.py | 9 +- .../cursor/abstract_file_based_cursor.py | 29 +- .../cursor/default_file_based_cursor.py | 45 +- .../stream/default_file_based_stream.py | 67 +-- airbyte_cdk/sources/file_based/types.py | 4 +- airbyte_cdk/sources/http_config.py | 3 + airbyte_cdk/sources/http_logger.py | 8 +- airbyte_cdk/sources/message/repository.py | 17 +- airbyte_cdk/sources/source.py | 22 +- .../sources/streams/availability_strategy.py | 26 +- airbyte_cdk/sources/streams/call_rate.py | 71 +-- .../streams/checkpoint/checkpoint_reader.py | 133 ++--- .../sources/streams/checkpoint/cursor.py | 31 +- .../per_partition_key_serializer.py | 7 +- .../resumable_full_refresh_cursor.py | 19 +- ...substream_resumable_full_refresh_cursor.py | 21 +- .../streams/concurrent/abstract_stream.py | 42 +- .../concurrent/abstract_stream_facade.py | 18 +- .../sources/streams/concurrent/adapters.py | 97 ++-- .../concurrent/availability_strategy.py | 31 +- .../sources/streams/concurrent/cursor.py | 75 +-- .../streams/concurrent/default_stream.py | 18 +- .../sources/streams/concurrent/exceptions.py | 5 +- .../sources/streams/concurrent/helpers.py | 27 +- .../streams/concurrent/partition_enqueuer.py | 12 +- .../streams/concurrent/partition_reader.py | 13 +- .../concurrent/partitions/partition.py | 29 +- .../partitions/partition_generator.py | 6 +- .../streams/concurrent/partitions/record.py | 13 +- .../streams/concurrent/partitions/types.py | 12 +- .../abstract_stream_state_converter.py | 51 +- .../datetime_stream_state_converter.py | 39 +- airbyte_cdk/sources/streams/core.py | 176 +++--- .../streams/http/availability_strategy.py | 12 +- .../http/error_handlers/backoff_strategy.py | 9 +- .../default_backoff_strategy.py | 8 +- .../error_handlers/default_error_mapping.py | 9 +- .../http/error_handlers/error_handler.py | 25 +- .../error_handlers/error_message_parser.py | 7 +- .../http_status_error_handler.py | 58 +- .../json_error_message_parser.py | 15 +- .../http/error_handlers/response_models.py | 17 +- .../sources/streams/http/exceptions.py | 21 +- airbyte_cdk/sources/streams/http/http.py | 269 ++++----- .../sources/streams/http/http_client.py | 114 ++-- 
.../sources/streams/http/rate_limiting.py | 17 +- .../requests_native_auth/abstract_oauth.py | 67 +-- .../requests_native_auth/abstract_token.py | 4 +- .../http/requests_native_auth/oauth.py | 77 ++- .../http/requests_native_auth/token.py | 15 +- airbyte_cdk/sources/types.py | 21 +- airbyte_cdk/sources/utils/casing.py | 2 +- airbyte_cdk/sources/utils/record_helper.py | 7 +- airbyte_cdk/sources/utils/schema_helpers.py | 60 +- airbyte_cdk/sources/utils/slice_logger.py | 20 +- airbyte_cdk/sources/utils/transform.py | 63 +- airbyte_cdk/sources/utils/types.py | 2 + airbyte_cdk/sql/_util/hashing.py | 1 + airbyte_cdk/sql/_util/name_normalizers.py | 1 + airbyte_cdk/sql/constants.py | 1 + airbyte_cdk/sql/exceptions.py | 1 + airbyte_cdk/sql/secrets.py | 8 +- airbyte_cdk/sql/shared/catalog_providers.py | 1 + airbyte_cdk/sql/shared/sql_processor.py | 16 +- airbyte_cdk/sql/types.py | 3 +- airbyte_cdk/test/catalog_builder.py | 25 +- airbyte_cdk/test/entrypoint_wrapper.py | 50 +- airbyte_cdk/test/mock_http/matcher.py | 4 +- airbyte_cdk/test/mock_http/mocker.py | 33 +- airbyte_cdk/test/mock_http/request.py | 30 +- airbyte_cdk/test/mock_http/response.py | 3 +- .../test/mock_http/response_builder.py | 87 ++- airbyte_cdk/test/state_builder.py | 11 +- airbyte_cdk/test/utils/data.py | 1 + airbyte_cdk/test/utils/http_mocking.py | 4 +- airbyte_cdk/test/utils/reading.py | 6 +- airbyte_cdk/utils/airbyte_secrets_utils.py | 27 +- airbyte_cdk/utils/analytics_message.py | 5 +- airbyte_cdk/utils/constants.py | 2 + airbyte_cdk/utils/datetime_format_inferrer.py | 17 +- airbyte_cdk/utils/event_timing.py | 22 +- airbyte_cdk/utils/is_cloud_environment.py | 5 +- airbyte_cdk/utils/mapping_helpers.py | 14 +- airbyte_cdk/utils/message_utils.py | 1 + airbyte_cdk/utils/oneof_option_config.py | 8 +- airbyte_cdk/utils/print_buffer.py | 13 +- airbyte_cdk/utils/schema_inferrer.py | 66 +-- .../utils/spec_schema_transformations.py | 4 +- airbyte_cdk/utils/stream_status_utils.py | 12 +- airbyte_cdk/utils/traced_exception.py | 43 +- bin/generate_component_manifest_files.py | 2 + docs/generate.py | 3 +- reference_docs/_source/conf.py | 3 + reference_docs/generate_rst_schema.py | 8 +- unit_tests/conftest.py | 3 +- .../test_connector_builder_handler.py | 25 +- .../connector_builder/test_message_grouper.py | 21 +- unit_tests/connector_builder/utils.py | 4 +- unit_tests/destinations/test_destination.py | 47 +- .../vector_db_based/config_test.py | 24 +- .../document_processor_test.py | 9 +- .../vector_db_based/embedder_test.py | 32 +- .../vector_db_based/writer_test.py | 14 +- .../test_concurrent_source_adapter.py | 15 +- unit_tests/sources/conftest.py | 6 +- .../declarative/async_job/test_integration.py | 18 +- .../sources/declarative/async_job/test_job.py | 5 +- .../async_job/test_job_orchestrator.py | 25 +- .../declarative/async_job/test_job_tracker.py | 6 +- .../sources/declarative/auth/test_jwt.py | 10 +- .../sources/declarative/auth/test_oauth.py | 28 +- .../auth/test_selective_authenticator.py | 2 + .../auth/test_session_token_auth.py | 5 +- .../declarative/auth/test_token_auth.py | 41 +- .../declarative/auth/test_token_provider.py | 4 +- .../declarative/checks/test_check_stream.py | 8 +- .../test_concurrency_level.py | 11 +- .../datetime/test_datetime_parser.py | 2 + .../datetime/test_min_max_datetime.py | 9 +- .../declarative/decoders/test_json_decoder.py | 3 + .../test_pagination_decoder_decorator.py | 3 + .../declarative/decoders/test_xml_decoder.py | 3 + .../sources/declarative/external_component.py | 5 +- 
.../extractors/test_dpath_extractor.py | 9 +- .../extractors/test_record_filter.py | 18 +- .../extractors/test_record_selector.py | 2 + .../test_response_to_file_extractor.py | 3 + .../incremental/test_datetime_based_cursor.py | 7 +- .../incremental/test_per_partition_cursor.py | 10 +- .../test_per_partition_cursor_integration.py | 14 +- .../test_resumable_full_refresh_cursor.py | 2 + .../declarative/interpolation/test_filters.py | 4 + .../test_interpolated_boolean.py | 3 + .../test_interpolated_mapping.py | 4 +- .../test_interpolated_nested_mapping.py | 4 +- .../interpolation/test_interpolated_string.py | 5 +- .../declarative/interpolation/test_jinja.py | 21 +- .../declarative/interpolation/test_macros.py | 6 +- .../test_legacy_to_per_partition_migration.py | 11 +- .../test_manifest_component_transformer.py | 2 + .../test_manifest_reference_resolver.py | 3 + .../test_model_to_component_factory.py | 21 +- .../declarative/parsers/testing_components.py | 14 +- ...test_cartesian_product_partition_router.py | 10 +- .../test_list_partition_router.py | 13 +- .../test_parent_state_stream.py | 31 +- .../test_single_partition_router.py | 1 + .../test_substream_partition_router.py | 25 +- .../test_constant_backoff.py | 3 + .../test_exponential_backoff.py | 3 + .../backoff_strategies/test_header_helper.py | 14 +- .../test_wait_time_from_header.py | 5 +- .../test_wait_until_time_from_header.py | 15 +- .../test_composite_error_handler.py | 3 + .../test_default_error_handler.py | 3 + .../test_default_http_response_filter.py | 4 +- .../test_http_response_filter.py | 2 + .../test_cursor_pagination_strategy.py | 2 + .../paginators/test_default_paginator.py | 2 + .../paginators/test_no_paginator.py | 2 + .../paginators/test_offset_increment.py | 8 +- .../paginators/test_page_increment.py | 6 +- .../paginators/test_request_option.py | 2 + .../paginators/test_stop_condition.py | 5 +- ...datetime_based_request_options_provider.py | 2 + ...t_interpolated_request_options_provider.py | 3 + .../requesters/test_http_job_repository.py | 4 +- .../requesters/test_http_requester.py | 29 +- ...est_interpolated_request_input_provider.py | 2 + .../retrievers/test_simple_retriever.py | 3 + .../schema/source_test/SourceTest.py | 1 + .../schema/test_default_schema_loader.py | 2 + .../schema/test_inline_schema_loader.py | 2 + .../schema/test_json_file_schema_loader.py | 5 +- .../sources/declarative/spec/test_spec.py | 2 + .../test_concurrent_declarative_source.py | 80 +-- .../declarative/test_declarative_stream.py | 7 +- .../test_manifest_declarative_source.py | 19 +- unit_tests/sources/declarative/test_types.py | 2 + .../test_yaml_declarative_source.py | 13 +- .../transformations/test_add_fields.py | 9 +- .../test_keys_to_lower_transformation.py | 2 + .../transformations/test_remove_fields.py | 7 +- .../embedded/test_embedded_integration.py | 6 +- ...efault_file_based_availability_strategy.py | 20 +- .../config/test_abstract_file_based_spec.py | 10 +- .../file_based/config/test_csv_format.py | 4 +- .../config/test_file_based_stream_config.py | 9 +- .../test_default_discovery_policy.py | 9 +- .../file_based/file_types/test_avro_parser.py | 6 +- .../file_based/file_types/test_csv_parser.py | 26 +- .../file_types/test_excel_parser.py | 3 +- .../file_types/test_jsonl_parser.py | 13 +- .../file_types/test_parquet_parser.py | 18 +- .../file_types/test_unstructured_parser.py | 7 +- unit_tests/sources/file_based/helpers.py | 21 +- .../file_based/in_memory_files_source.py | 60 +- .../file_based/scenarios/avro_scenarios.py | 2 + 
.../file_based/scenarios/check_scenarios.py | 5 +- .../concurrent_incremental_scenarios.py | 7 +- .../file_based/scenarios/csv_scenarios.py | 13 +- .../file_based/scenarios/excel_scenarios.py | 2 + .../scenarios/file_based_source_builder.py | 49 +- .../scenarios/incremental_scenarios.py | 7 +- .../file_based/scenarios/jsonl_scenarios.py | 9 +- .../file_based/scenarios/parquet_scenarios.py | 11 +- .../file_based/scenarios/scenario_builder.py | 126 ++-- .../scenarios/unstructured_scenarios.py | 7 +- .../scenarios/user_input_schema_scenarios.py | 6 +- .../scenarios/validation_policy_scenarios.py | 6 +- .../test_default_schema_validation_policy.py | 6 +- .../stream/concurrent/test_adapters.py | 6 +- .../test_file_based_concurrent_cursor.py | 35 +- .../stream/test_default_file_based_cursor.py | 15 +- .../stream/test_default_file_based_stream.py | 12 +- .../file_based/test_file_based_scenarios.py | 6 +- .../test_file_based_stream_reader.py | 23 +- .../sources/file_based/test_scenarios.py | 34 +- .../sources/file_based/test_schema_helpers.py | 12 +- .../sources/fixtures/source_test_fixture.py | 33 +- unit_tests/sources/message/test_repository.py | 9 +- .../mock_server_tests/mock_source_fixture.py | 95 ++- .../airbyte_message_assertions.py | 8 +- .../test_mock_server_abstract_source.py | 43 +- .../test_resumable_full_refresh.py | 24 +- .../checkpoint/test_checkpoint_reader.py | 2 + ...substream_resumable_full_refresh_cursor.py | 10 +- .../scenarios/incremental_scenarios.py | 16 +- .../scenarios/stream_facade_builder.py | 44 +- .../scenarios/stream_facade_scenarios.py | 8 +- .../scenarios/test_concurrent_scenarios.py | 3 + ...hread_based_concurrent_stream_scenarios.py | 15 +- ..._based_concurrent_stream_source_builder.py | 41 +- .../streams/concurrent/scenarios/utils.py | 23 +- .../streams/concurrent/test_adapters.py | 4 + .../test_concurrent_read_processor.py | 8 +- .../sources/streams/concurrent/test_cursor.py | 18 +- .../test_datetime_state_converter.py | 6 +- .../streams/concurrent/test_default_stream.py | 2 + .../concurrent/test_partition_enqueuer.py | 15 +- .../concurrent/test_partition_reader.py | 10 +- .../concurrent/test_thread_pool_manager.py | 2 + .../test_default_backoff_strategy.py | 9 +- .../test_http_status_error_handler.py | 4 + .../test_json_error_message_parser.py | 2 + .../error_handlers/test_response_models.py | 3 + .../test_requests_native_auth.py | 68 +-- .../http/test_availability_strategy.py | 15 +- unit_tests/sources/streams/http/test_http.py | 159 +++-- .../sources/streams/http/test_http_client.py | 18 +- unit_tests/sources/streams/test_call_rate.py | 12 +- .../sources/streams/test_stream_read.py | 37 +- .../sources/streams/test_streams_core.py | 130 ++--- .../streams/utils/test_stream_helper.py | 2 + unit_tests/sources/test_abstract_source.py | 134 ++--- unit_tests/sources/test_config.py | 10 +- .../sources/test_connector_state_manager.py | 5 +- unit_tests/sources/test_http_logger.py | 11 +- unit_tests/sources/test_integration_source.py | 18 +- unit_tests/sources/test_source.py | 17 +- unit_tests/sources/test_source_read.py | 37 +- .../sources/utils/test_record_helper.py | 3 + .../sources/utils/test_schema_helpers.py | 10 +- unit_tests/sources/utils/test_slice_logger.py | 6 +- unit_tests/sources/utils/test_transform.py | 3 + unit_tests/test/mock_http/test_matcher.py | 1 + unit_tests/test/mock_http/test_mocker.py | 3 + unit_tests/test/mock_http/test_request.py | 2 + .../test/mock_http/test_response_builder.py | 22 +- unit_tests/test/test_entrypoint_wrapper.py | 13 +- 
unit_tests/test_config_observation.py | 2 + unit_tests/test_connector.py | 11 +- unit_tests/test_counter.py | 2 +- unit_tests/test_entrypoint.py | 16 +- unit_tests/test_exception_handler.py | 3 +- unit_tests/test_logger.py | 5 +- unit_tests/test_secure_logger.py | 8 +- .../utils/test_datetime_format_inferrer.py | 7 +- unit_tests/utils/test_mapping_helpers.py | 2 + unit_tests/utils/test_message_utils.py | 2 + unit_tests/utils/test_rate_limiting.py | 4 +- unit_tests/utils/test_schema_inferrer.py | 9 +- unit_tests/utils/test_secret_utils.py | 3 + unit_tests/utils/test_stream_status_utils.py | 2 + unit_tests/utils/test_traced_exception.py | 8 +- 454 files changed, 5509 insertions(+), 5833 deletions(-) diff --git a/airbyte_cdk/config_observation.py b/airbyte_cdk/config_observation.py index 764174f0..4a80a8e9 100644 --- a/airbyte_cdk/config_observation.py +++ b/airbyte_cdk/config_observation.py @@ -7,8 +7,11 @@ ) import time +from collections.abc import MutableMapping from copy import copy -from typing import Any, List, MutableMapping +from typing import Any + +from orjson import orjson from airbyte_cdk.models import ( AirbyteControlConnectorConfigMessage, @@ -18,7 +21,6 @@ OrchestratorType, Type, ) -from orjson import orjson class ObservedDict(dict): # type: ignore # disallow_any_generics is set to True, and dict is equivalent to dict[Any] @@ -37,7 +39,7 @@ def __init__( non_observed_mapping[item] = ObservedDict(value, observer) # Observe nested list of dicts - if isinstance(value, List): + if isinstance(value, list): for i, sub_value in enumerate(value): if isinstance(sub_value, MutableMapping): value[i] = ObservedDict(sub_value, observer) @@ -51,7 +53,7 @@ def __setitem__(self, item: Any, value: Any) -> None: previous_value = self.get(item) if isinstance(value, MutableMapping): value = ObservedDict(value, self.observer) - if isinstance(value, List): + if isinstance(value, list): for i, sub_value in enumerate(value): if isinstance(sub_value, MutableMapping): value[i] = ObservedDict(sub_value, self.observer) @@ -86,8 +88,7 @@ def observe_connector_config( def emit_configuration_as_airbyte_control_message(config: MutableMapping[str, Any]) -> None: - """ - WARNING: deprecated - emit_configuration_as_airbyte_control_message is being deprecated in favor of the MessageRepository mechanism. + """WARNING: deprecated - emit_configuration_as_airbyte_control_message is being deprecated in favor of the MessageRepository mechanism. See the airbyte_cdk.sources.message package """ airbyte_message = create_connector_config_control_message(config) diff --git a/airbyte_cdk/connector.py b/airbyte_cdk/connector.py index 29cfc968..9b0da16e 100644 --- a/airbyte_cdk/connector.py +++ b/airbyte_cdk/connector.py @@ -1,16 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - +from __future__ import annotations import json import logging import os import pkgutil from abc import ABC, abstractmethod -from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar +from collections.abc import Mapping +from typing import Any, Generic, Protocol, TypeVar import yaml + from airbyte_cdk.models import ( AirbyteConnectionStatus, ConnectorSpecification, @@ -18,7 +20,7 @@ ) -def load_optional_package_file(package: str, filename: str) -> Optional[bytes]: +def load_optional_package_file(package: str, filename: str) -> bytes | None: """Gets a resource from a package, returning None if it does not exist""" try: return pkgutil.get_data(package, filename) @@ -35,23 +37,20 @@ class BaseConnector(ABC, Generic[TConfig]): @abstractmethod def configure(self, config: Mapping[str, Any], temp_dir: str) -> TConfig: - """ - Persist config in temporary directory to run the Source job - """ + """Persist config in temporary directory to run the Source job""" @staticmethod def read_config(config_path: str) -> Mapping[str, Any]: config = BaseConnector._read_json_file(config_path) if isinstance(config, Mapping): return config - else: - raise ValueError( - f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config." - ) + raise ValueError( + f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config." + ) @staticmethod def _read_json_file(file_path: str) -> Any: - with open(file_path, "r") as file: + with open(file_path) as file: contents = file.read() try: @@ -67,11 +66,9 @@ def write_config(config: TConfig, config_path: str) -> None: fh.write(json.dumps(config)) def spec(self, logger: logging.Logger) -> ConnectorSpecification: - """ - Returns the spec for this integration. The spec is a JSON-Schema object describing the required configurations (e.g: username and password) + """Returns the spec for this integration. The spec is a JSON-Schema object describing the required configurations (e.g: username and password) required to run this integration. By default, this will be loaded from a "spec.yaml" or a "spec.json" in the package root. """ - package = self.__class__.__module__.split(".")[0] yaml_spec = load_optional_package_file(package, "spec.yaml") @@ -98,8 +95,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: @abstractmethod def check(self, logger: logging.Logger, config: TConfig) -> AirbyteConnectionStatus: - """ - Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect + """Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect to the Stripe API. """ diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py index 44d1bfe1..e90be3ff 100644 --- a/airbyte_cdk/connector_builder/connector_builder_handler.py +++ b/airbyte_cdk/connector_builder/connector_builder_handler.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import dataclasses +from collections.abc import Mapping from datetime import datetime -from typing import Any, List, Mapping +from typing import Any from airbyte_cdk.connector_builder.message_grouper import MessageGrouper from airbyte_cdk.models import ( @@ -12,8 +14,8 @@ AirbyteRecordMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, + Type, ) -from airbyte_cdk.models import Type from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource @@ -23,6 +25,7 @@ from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException + DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5 DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5 DEFAULT_MAXIMUM_RECORDS = 100 @@ -68,7 +71,7 @@ def read_stream( source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], + state: list[AirbyteStateMessage], limits: TestReadLimits, ) -> AirbyteMessage: try: @@ -89,7 +92,7 @@ def read_stream( error = AirbyteTracedException.from_exception( exc, message=filter_secrets( - f"Error reading stream with config={config} and catalog={configured_catalog}: {str(exc)}" + f"Error reading stream with config={config} and catalog={configured_catalog}: {exc!s}" ), ) return error.as_airbyte_message() @@ -107,7 +110,7 @@ def resolve_manifest(source: ManifestDeclarativeSource) -> AirbyteMessage: ) except Exception as exc: error = AirbyteTracedException.from_exception( - exc, message=f"Error resolving manifest: {str(exc)}" + exc, message=f"Error resolving manifest: {exc!s}" ) return error.as_airbyte_message() diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py index 35ba7e46..d110bf2e 100644 --- a/airbyte_cdk/connector_builder/main.py +++ b/airbyte_cdk/connector_builder/main.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import sys -from typing import Any, List, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any + +from orjson import orjson from airbyte_cdk.connector import BaseConnector from airbyte_cdk.connector_builder.connector_builder_handler import ( @@ -25,12 +28,11 @@ from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.source import Source from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from orjson import orjson def get_config_and_catalog_from_args( - args: List[str], -) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]: + args: list[str], +) -> tuple[str, Mapping[str, Any], ConfiguredAirbyteCatalog | None, Any]: # TODO: Add functionality for the `debug` logger. # Currently, no one `debug` level log will be displayed during `read` a stream for a connector created through `connector-builder`. 
parsed_args = AirbyteEntrypoint.parse_args(args) @@ -69,22 +71,21 @@ def handle_connector_builder_request( source: ManifestDeclarativeSource, command: str, config: Mapping[str, Any], - catalog: Optional[ConfiguredAirbyteCatalog], - state: List[AirbyteStateMessage], + catalog: ConfiguredAirbyteCatalog | None, + state: list[AirbyteStateMessage], limits: TestReadLimits, ) -> AirbyteMessage: if command == "resolve_manifest": return resolve_manifest(source) - elif command == "test_read": + if command == "test_read": assert ( catalog is not None ), "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None." return read_stream(source, config, catalog, state, limits) - else: - raise ValueError(f"Unrecognized command {command}.") + raise ValueError(f"Unrecognized command {command}.") -def handle_request(args: List[str]) -> str: +def handle_request(args: list[str]) -> str: command, config, catalog, state = get_config_and_catalog_from_args(args) limits = get_limits(config) source = create_source(config, limits) @@ -100,7 +101,7 @@ def handle_request(args: List[str]) -> str: print(handle_request(sys.argv[1:])) except Exception as exc: error = AirbyteTracedException.from_exception( - exc, message=f"Error handling request: {str(exc)}" + exc, message=f"Error handling request: {exc!s}" ) m = error.as_airbyte_message() print(orjson.dumps(AirbyteMessageSerializer.dump(m)).decode()) diff --git a/airbyte_cdk/connector_builder/message_grouper.py b/airbyte_cdk/connector_builder/message_grouper.py index aa3a4293..c9fae791 100644 --- a/airbyte_cdk/connector_builder/message_grouper.py +++ b/airbyte_cdk/connector_builder/message_grouper.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging +from collections.abc import Iterable, Iterator, Mapping from copy import deepcopy from json import JSONDecodeError -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union +from typing import Any from airbyte_cdk.connector_builder.models import ( AuxiliaryRequest, @@ -46,8 +48,8 @@ def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: self._max_record_limit = max_record_limit def _pk_to_nested_and_composite_field( - self, field: Optional[Union[str, List[str], List[List[str]]]] - ) -> List[List[str]]: + self, field: str | list[str] | list[list[str]] | None + ) -> list[list[str]]: if not field: return [[]] @@ -61,8 +63,8 @@ def _pk_to_nested_and_composite_field( return field # type: ignore # the type of field is expected to be List[List[str]] here def _cursor_field_to_nested_and_composite_field( - self, field: Union[str, List[str]] - ) -> List[List[str]]: + self, field: str | list[str] + ) -> list[list[str]]: if not field: return [[]] @@ -80,8 +82,8 @@ def get_message_groups( source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], - record_limit: Optional[int] = None, + state: list[AirbyteStateMessage], + record_limit: int | None = None, ) -> StreamRead: if record_limit is not None and not (1 <= record_limit <= self._max_record_limit): raise ValueError( @@ -113,20 +115,16 @@ def get_message_groups( ): if isinstance(message_group, AirbyteLogMessage): log_messages.append( - LogMessage( - **{"message": message_group.message, "level": message_group.level.value} - ) + LogMessage(message=message_group.message, level=message_group.level.value) ) elif isinstance(message_group, AirbyteTraceMessage): if 
message_group.type == TraceType.ERROR: log_messages.append( LogMessage( - **{ - "message": message_group.error.message, - "level": "ERROR", - "internal_message": message_group.error.internal_message, - "stacktrace": message_group.error.stack_trace, - } + message=message_group.error.message, + level="ERROR", + internal_message=message_group.error.internal_message, + stacktrace=message_group.error.stack_trace, ) ) elif isinstance(message_group, AirbyteControlMessage): @@ -170,16 +168,13 @@ def _get_message_groups( datetime_format_inferrer: DatetimeFormatInferrer, limit: int, ) -> Iterable[ - Union[ - StreamReadPages, - AirbyteControlMessage, - AirbyteLogMessage, - AirbyteTraceMessage, - AuxiliaryRequest, - ] + StreamReadPages + | AirbyteControlMessage + | AirbyteLogMessage + | AirbyteTraceMessage + | AuxiliaryRequest ]: - """ - Message groups are partitioned according to when request log messages are received. Subsequent response log messages + """Message groups are partitioned according to when request log messages are received. Subsequent response log messages and record messages belong to the prior request log message and when we encounter another request, append the latest message group, until records have been read. @@ -195,12 +190,12 @@ def _get_message_groups( """ records_count = 0 at_least_one_page_in_group = False - current_page_records: List[Mapping[str, Any]] = [] - current_slice_descriptor: Optional[Dict[str, Any]] = None - current_slice_pages: List[StreamReadPages] = [] - current_page_request: Optional[HttpRequest] = None - current_page_response: Optional[HttpResponse] = None - latest_state_message: Optional[Dict[str, Any]] = None + current_page_records: list[Mapping[str, Any]] = [] + current_slice_descriptor: dict[str, Any] | None = None + current_slice_pages: list[StreamReadPages] = [] + current_page_request: HttpRequest | None = None + current_page_response: HttpResponse | None = None + latest_state_message: dict[str, Any] | None = None while records_count < limit and (message := next(messages, None)): json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None @@ -208,7 +203,7 @@ def _get_message_groups( raise ValueError( f"Expected log message to be a dict, got {json_object} of type {type(json_object)}" ) - json_message: Optional[Dict[str, JsonType]] = json_object + json_message: dict[str, JsonType] | None = json_object if self._need_to_close_page(at_least_one_page_in_group, message, json_message): self._close_page( current_page_request, @@ -285,25 +280,24 @@ def _get_message_groups( yield message.control elif message.type == MessageType.STATE: latest_state_message = message.state # type: ignore[assignment] - else: - if current_page_request or current_page_response or current_page_records: - self._close_page( - current_page_request, - current_page_response, - current_slice_pages, - current_page_records, - ) - yield StreamReadSlices( - pages=current_slice_pages, - slice_descriptor=current_slice_descriptor, - state=[latest_state_message] if latest_state_message else [], - ) + if current_page_request or current_page_response or current_page_records: + self._close_page( + current_page_request, + current_page_response, + current_slice_pages, + current_page_records, + ) + yield StreamReadSlices( + pages=current_slice_pages, + slice_descriptor=current_slice_descriptor, + state=[latest_state_message] if latest_state_message else [], + ) @staticmethod def _need_to_close_page( at_least_one_page_in_group: bool, message: AirbyteMessage, - json_message: 
Optional[Dict[str, Any]], + json_message: dict[str, Any] | None, ) -> bool: return ( at_least_one_page_in_group @@ -315,22 +309,20 @@ def _need_to_close_page( ) @staticmethod - def _is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool: + def _is_page_http_request(json_message: dict[str, Any] | None) -> bool: if not json_message: return False - else: - return MessageGrouper._is_http_log( - json_message - ) and not MessageGrouper._is_auxiliary_http_request(json_message) + return MessageGrouper._is_http_log( + json_message + ) and not MessageGrouper._is_auxiliary_http_request(json_message) @staticmethod - def _is_http_log(message: Dict[str, JsonType]) -> bool: + def _is_http_log(message: dict[str, JsonType]) -> bool: return bool(message.get("http", False)) @staticmethod - def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool: - """ - A auxiliary request is a request that is performed and will not directly lead to record for the specific stream it is being queried. + def _is_auxiliary_http_request(message: dict[str, Any] | None) -> bool: + """A auxiliary request is a request that is performed and will not directly lead to record for the specific stream it is being queried. A couple of examples are: * OAuth authentication * Substream slice generation @@ -343,14 +335,12 @@ def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool: @staticmethod def _close_page( - current_page_request: Optional[HttpRequest], - current_page_response: Optional[HttpResponse], - current_slice_pages: List[StreamReadPages], - current_page_records: List[Mapping[str, Any]], + current_page_request: HttpRequest | None, + current_page_response: HttpResponse | None, + current_slice_pages: list[StreamReadPages], + current_page_records: list[Mapping[str, Any]], ) -> None: - """ - Close a page when parsing message groups - """ + """Close a page when parsing message groups""" current_slice_pages.append( StreamReadPages( request=current_page_request, @@ -365,7 +355,7 @@ def _read_stream( source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], + state: list[AirbyteStateMessage], ) -> Iterator[AirbyteMessage]: # the generator can raise an exception # iterate over the generated messages. 
if next raise an exception, catch it and yield it as an AirbyteLogMessage @@ -403,7 +393,7 @@ def _parse_json(log_message: AirbyteLogMessage) -> JsonType: return None @staticmethod - def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpRequest: + def _create_request_from_log_message(json_http_message: dict[str, Any]) -> HttpRequest: url = json_http_message.get("url", {}).get("full", "") request = json_http_message.get("http", {}).get("request", {}) return HttpRequest( @@ -414,14 +404,14 @@ def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpR ) @staticmethod - def _create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse: + def _create_response_from_log_message(json_http_message: dict[str, Any]) -> HttpResponse: response = json_http_message.get("http", {}).get("response", {}) body = response.get("body", {}).get("content", "") return HttpResponse( status=response.get("status_code"), body=body, headers=response.get("headers") ) - def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: + def _has_reached_limit(self, slices: list[StreamReadSlices]) -> bool: if len(slices) >= self._max_slices: return True @@ -436,13 +426,13 @@ def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: return True return False - def _parse_slice_description(self, log_message: str) -> Dict[str, Any]: + def _parse_slice_description(self, log_message: str) -> dict[str, Any]: return json.loads(log_message.replace(SliceLogger.SLICE_LOG_PREFIX, "", 1)) # type: ignore @staticmethod - def _clean_config(config: Dict[str, Any]) -> Dict[str, Any]: + def _clean_config(config: dict[str, Any]) -> dict[str, Any]: cleaned_config = deepcopy(config) - for key in config.keys(): + for key in config: if key.startswith("__"): del cleaned_config[key] return cleaned_config diff --git a/airbyte_cdk/connector_builder/models.py b/airbyte_cdk/connector_builder/models.py index 50eb8eb9..bd46cee3 100644 --- a/airbyte_cdk/connector_builder/models.py +++ b/airbyte_cdk/connector_builder/models.py @@ -1,46 +1,47 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any @dataclass class HttpResponse: status: int - body: Optional[str] = None - headers: Optional[Dict[str, Any]] = None + body: str | None = None + headers: dict[str, Any] | None = None @dataclass class HttpRequest: url: str - headers: Optional[Dict[str, Any]] + headers: dict[str, Any] | None http_method: str - body: Optional[str] = None + body: str | None = None @dataclass class StreamReadPages: - records: List[object] - request: Optional[HttpRequest] = None - response: Optional[HttpResponse] = None + records: list[object] + request: HttpRequest | None = None + response: HttpResponse | None = None @dataclass class StreamReadSlices: - pages: List[StreamReadPages] - slice_descriptor: Optional[Dict[str, Any]] - state: Optional[List[Dict[str, Any]]] = None + pages: list[StreamReadPages] + slice_descriptor: dict[str, Any] | None + state: list[dict[str, Any]] | None = None @dataclass class LogMessage: message: str level: str - internal_message: Optional[str] = None - stacktrace: Optional[str] = None + internal_message: str | None = None + stacktrace: str | None = None @dataclass @@ -52,20 +53,20 @@ class AuxiliaryRequest: @dataclass -class StreamRead(object): - logs: List[LogMessage] - slices: List[StreamReadSlices] +class StreamRead: + logs: list[LogMessage] + slices: list[StreamReadSlices] test_read_limit_reached: bool - auxiliary_requests: List[AuxiliaryRequest] - inferred_schema: Optional[Dict[str, Any]] - inferred_datetime_formats: Optional[Dict[str, str]] - latest_config_update: Optional[Dict[str, Any]] + auxiliary_requests: list[AuxiliaryRequest] + inferred_schema: dict[str, Any] | None + inferred_datetime_formats: dict[str, str] | None + latest_config_update: dict[str, Any] | None @dataclass class StreamReadRequestBody: - manifest: Dict[str, Any] + manifest: dict[str, Any] stream: str - config: Dict[str, Any] - state: Optional[Dict[str, Any]] - record_limit: Optional[int] + config: dict[str, Any] + state: dict[str, Any] | None + record_limit: int | None diff --git a/airbyte_cdk/destinations/destination.py b/airbyte_cdk/destinations/destination.py index febf4a1b..74644c0c 100644 --- a/airbyte_cdk/destinations/destination.py +++ b/airbyte_cdk/destinations/destination.py @@ -1,13 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import argparse import io import logging import sys from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Mapping +from collections.abc import Iterable, Mapping +from typing import Any + +from orjson import orjson from airbyte_cdk.connector import Connector from airbyte_cdk.exception_handler import init_uncaught_exception_handler @@ -20,7 +24,7 @@ ) from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from orjson import orjson + logger = logging.getLogger("airbyte") @@ -67,12 +71,10 @@ def _run_write( ) logger.info("Writing complete.") - def parse_args(self, args: List[str]) -> argparse.Namespace: - """ - :param args: commandline arguments + def parse_args(self, args: list[str]) -> argparse.Namespace: + """:param args: commandline arguments :return: """ - parent_parser = argparse.ArgumentParser(add_help=False) main_parser = argparse.ArgumentParser() subparsers = main_parser.add_subparsers(title="commands", dest="command") @@ -107,7 +109,7 @@ def parse_args(self, args: List[str]) -> argparse.Namespace: cmd = parsed_args.command if not cmd: raise Exception("No command entered. ") - elif cmd not in ["spec", "check", "write"]: + if cmd not in ["spec", "check", "write"]: # This is technically dead code since parse_args() would fail if this was the case # But it's non-obvious enough to warrant placing it here anyways raise Exception(f"Unknown command entered: {cmd}") @@ -145,7 +147,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: input_stream=wrapped_stdin, ) - def run(self, args: List[str]) -> None: + def run(self, args: list[str]) -> None: init_uncaught_exception_handler(logger) parsed_args = self.parse_args(args) output_messages = self.run_cmd(parsed_args) diff --git a/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte_cdk/destinations/vector_db_based/config.py index 904f40d3..2e2b3d9d 100644 --- a/airbyte_cdk/destinations/vector_db_based/config.py +++ b/airbyte_cdk/destinations/vector_db_based/config.py @@ -1,18 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Union import dpath +from pydantic.v1 import BaseModel, Field + from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig from airbyte_cdk.utils.spec_schema_transformations import resolve_refs -from pydantic.v1 import BaseModel, Field class SeparatorSplitterConfigModel(BaseModel): mode: Literal["separator"] = Field("separator", const=True) - separators: List[str] = Field( + separators: list[str] = Field( default=['"\\n\\n"', '"\\n"', '" "', '""'], title="Separators", description='List of separator strings to split text fields by. The separator itself needs to be wrapped in double quotes, e.g. to split by the dot character, use ".". To split by a newline, use "\\n".', @@ -101,14 +103,14 @@ class ProcessingConfigModel(BaseModel): description="Size of overlap between chunks in tokens to store in vector store to better capture relevant context", default=0, ) - text_fields: Optional[List[str]] = Field( + text_fields: list[str] | None = Field( default=[], title="Text fields to embed", description="List of fields in the record that should be used to calculate the embedding. The field list is applied to all streams in the same way and non-existing fields are ignored. 
If none are defined, all fields are considered text fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array.", always_show=True, examples=["text", "user.name", "users.*.name"], ) - metadata_fields: Optional[List[str]] = Field( + metadata_fields: list[str] | None = Field( default=[], title="Fields to store as metadata", description="List of fields in the record that should be stored as metadata. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered metadata fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array. When specifying nested paths, all matching values are flattened into an array set to a field named by the path.", @@ -122,7 +124,7 @@ class ProcessingConfigModel(BaseModel): type="object", description="Split text fields into chunks based on the specified method.", ) - field_name_mappings: Optional[List[FieldNameMappingConfigModel]] = Field( + field_name_mappings: list[FieldNameMappingConfigModel] | None = Field( default=[], title="Field name mappings", description="List of fields to rename. Not applicable for nested fields, but can be used to rename fields already flattened via dot notation.", @@ -237,8 +239,7 @@ class Config(OneOfOptionConfig): class VectorDBConfigModel(BaseModel): - """ - The configuration model for the Vector DB based destinations. This model is used to generate the UI for the destination configuration, + """The configuration model for the Vector DB based destinations. This model is used to generate the UI for the destination configuration, as well as to provide type safety for the configuration passed to the destination. The configuration model is composed of four parts: @@ -250,13 +251,13 @@ class VectorDBConfigModel(BaseModel): Processing, embedding and advanced configuration are provided by this base class, while the indexing configuration is provided by the destination connector in the sub class. 
""" - embedding: Union[ - OpenAIEmbeddingConfigModel, - CohereEmbeddingConfigModel, - FakeEmbeddingConfigModel, - AzureOpenAIEmbeddingConfigModel, - OpenAICompatibleEmbeddingConfigModel, - ] = Field( + embedding: ( + OpenAIEmbeddingConfigModel + | CohereEmbeddingConfigModel + | FakeEmbeddingConfigModel + | AzureOpenAIEmbeddingConfigModel + | OpenAICompatibleEmbeddingConfigModel + ) = Field( ..., title="Embedding", description="Embedding configuration", @@ -284,14 +285,14 @@ class Config: } @staticmethod - def remove_discriminator(schema: Dict[str, Any]) -> None: - """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" + def remove_discriminator(schema: dict[str, Any]) -> None: + """Pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" dpath.delete(schema, "properties/**/discriminator") @classmethod - def schema(cls, by_alias: bool = True, ref_template: str = "") -> Dict[str, Any]: - """we're overriding the schema classmethod to enable some post-processing""" - schema: Dict[str, Any] = super().schema() + def schema(cls, by_alias: bool = True, ref_template: str = "") -> dict[str, Any]: + """We're overriding the schema classmethod to enable some post-processing""" + schema: dict[str, Any] = super().schema() schema = resolve_refs(schema) cls.remove_discriminator(schema) return schema diff --git a/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte_cdk/destinations/vector_db_based/document_processor.py index 6e1723cb..983d90d6 100644 --- a/airbyte_cdk/destinations/vector_db_based/document_processor.py +++ b/airbyte_cdk/destinations/vector_db_based/document_processor.py @@ -1,13 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any import dpath +from langchain.text_splitter import Language, RecursiveCharacterTextSplitter +from langchain.utils import stringify_dict +from langchain_core.documents.base import Document + from airbyte_cdk.destinations.vector_db_based.config import ( ProcessingConfigModel, SeparatorSplitterConfigModel, @@ -21,9 +27,7 @@ DestinationSyncMode, ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType -from langchain.text_splitter import Language, RecursiveCharacterTextSplitter -from langchain.utils import stringify_dict -from langchain_core.documents.base import Document + METADATA_STREAM_FIELD = "_ab_stream" METADATA_RECORD_ID_FIELD = "_ab_record_id" @@ -33,10 +37,10 @@ @dataclass class Chunk: - page_content: Optional[str] - metadata: Dict[str, Any] + page_content: str | None + metadata: dict[str, Any] record: AirbyteRecordMessage - embedding: Optional[List[float]] = None + embedding: list[float] | None = None headers_to_split_on = [ @@ -50,8 +54,7 @@ class Chunk: class DocumentProcessor: - """ - DocumentProcessor is a helper class that generates documents from Airbyte records. + """DocumentProcessor is a helper class that generates documents from Airbyte records. It is used to generate documents from records before writing them to the destination: * The text fields are extracted from the record and concatenated to a single string. 
@@ -68,7 +71,7 @@ class DocumentProcessor: streams: Mapping[str, ConfiguredAirbyteStream] @staticmethod - def check_config(config: ProcessingConfigModel) -> Optional[str]: + def check_config(config: ProcessingConfigModel) -> str | None: if config.text_splitter is not None and config.text_splitter.mode == "separator": for s in config.text_splitter.separators: try: @@ -83,7 +86,7 @@ def _get_text_splitter( self, chunk_size: int, chunk_overlap: int, - splitter_config: Optional[TextSplitterConfigModel], + splitter_config: TextSplitterConfigModel | None, ) -> RecursiveCharacterTextSplitter: if splitter_config is None: splitter_config = SeparatorSplitterConfigModel(mode="separator") @@ -127,13 +130,12 @@ def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCata self.field_name_mappings = config.field_name_mappings self.logger = logging.getLogger("airbyte.document_processor") - def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[str]]: - """ - Generate documents from records. + def process(self, record: AirbyteRecordMessage) -> tuple[list[Chunk], str | None]: + """Generate documents from records. :param records: List of AirbyteRecordMessages :return: Tuple of (List of document chunks, record id to delete if a stream is in dedup mode to avoid stale documents in the vector store) """ - if CDC_DELETED_FIELD in record.data and record.data[CDC_DELETED_FIELD]: + if record.data.get(CDC_DELETED_FIELD): return [], self._extract_primary_key(record) doc = self._generate_document(record) if doc is None: @@ -158,7 +160,7 @@ def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[s ) return chunks, id_to_delete - def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]: + def _generate_document(self, record: AirbyteRecordMessage) -> Document | None: relevant_fields = self._extract_relevant_fields(record, self.text_fields) if len(relevant_fields) == 0: return None @@ -167,8 +169,8 @@ def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document] return Document(page_content=text, metadata=metadata) def _extract_relevant_fields( - self, record: AirbyteRecordMessage, fields: Optional[List[str]] - ) -> Dict[str, Any]: + self, record: AirbyteRecordMessage, fields: list[str] | None + ) -> dict[str, Any]: relevant_fields = {} if fields and len(fields) > 0: for field in fields: @@ -179,7 +181,7 @@ def _extract_relevant_fields( relevant_fields = record.data return self._remap_field_names(relevant_fields) - def _extract_metadata(self, record: AirbyteRecordMessage) -> Dict[str, Any]: + def _extract_metadata(self, record: AirbyteRecordMessage) -> dict[str, Any]: metadata = self._extract_relevant_fields(record, self.metadata_fields) metadata[METADATA_STREAM_FIELD] = create_stream_identifier(record) primary_key = self._extract_primary_key(record) @@ -187,7 +189,7 @@ def _extract_metadata(self, record: AirbyteRecordMessage) -> Dict[str, Any]: metadata[METADATA_RECORD_ID_FIELD] = primary_key return metadata - def _extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]: + def _extract_primary_key(self, record: AirbyteRecordMessage) -> str | None: stream_identifier = create_stream_identifier(record) current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier] # if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones @@ -206,11 +208,11 @@ def _extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]: 
stringified_primary_key = "_".join(primary_key) return f"{stream_identifier}_{stringified_primary_key}" - def _split_document(self, doc: Document) -> List[Document]: - chunks: List[Document] = self.splitter.split_documents([doc]) + def _split_document(self, doc: Document) -> list[Document]: + chunks: list[Document] = self.splitter.split_documents([doc]) return chunks - def _remap_field_names(self, fields: Dict[str, Any]) -> Dict[str, Any]: + def _remap_field_names(self, fields: dict[str, Any]) -> dict[str, Any]: if not self.field_name_mappings: return fields diff --git a/airbyte_cdk/destinations/vector_db_based/embedder.py b/airbyte_cdk/destinations/vector_db_based/embedder.py index 4ec56fbf..cba3e5f3 100644 --- a/airbyte_cdk/destinations/vector_db_based/embedder.py +++ b/airbyte_cdk/destinations/vector_db_based/embedder.py @@ -1,11 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import cast + +from langchain.embeddings.cohere import CohereEmbeddings +from langchain.embeddings.fake import FakeEmbeddings +from langchain.embeddings.localai import LocalAIEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings from airbyte_cdk.destinations.vector_db_based.config import ( AzureOpenAIEmbeddingConfigModel, @@ -19,10 +25,6 @@ from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception from airbyte_cdk.models import AirbyteRecordMessage from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType -from langchain.embeddings.cohere import CohereEmbeddings -from langchain.embeddings.fake import FakeEmbeddings -from langchain.embeddings.localai import LocalAIEmbeddings -from langchain.embeddings.openai import OpenAIEmbeddings @dataclass @@ -32,8 +34,7 @@ class Document: class Embedder(ABC): - """ - Embedder is an abstract class that defines the interface for embedding text. + """Embedder is an abstract class that defines the interface for embedding text. The Indexer class uses the Embedder class to internally embed text - each indexer is responsible to pass the text of all documents to the embedder and store the resulting embeddings in the destination. The destination connector is responsible to create an embedder instance and pass it to the writer. @@ -44,13 +45,12 @@ def __init__(self) -> None: pass @abstractmethod - def check(self) -> Optional[str]: + def check(self) -> str | None: pass @abstractmethod - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: - """ - Embed the text of each chunk and return the resulting embedding vectors. + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: + """Embed the text of each chunk and return the resulting embedding vectors. If a chunk cannot be embedded or is configured to not be embedded, return None for that chunk. """ pass @@ -72,16 +72,15 @@ def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int): self.embeddings = embeddings self.chunk_size = chunk_size - def check(self) -> Optional[str]: + def check(self) -> str | None: try: self.embeddings.embed_query("test") except Exception as e: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: - """ - Embed the text of each chunk and return the resulting embedding vectors. 
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: + """Embed the text of each chunk and return the resulting embedding vectors. As the OpenAI API will fail if more than the per-minute limit worth of tokens is sent at once, we split the request into batches and embed each batch separately. It's still possible to run into the rate limit between each embed call because the available token budget hasn't recovered between the calls, @@ -90,7 +89,7 @@ def embed_documents(self, documents: List[Document]) -> List[Optional[List[float # Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of documents that can be embedded at once without exhausting the limit in a single request embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size batches = create_chunks(documents, batch_size=embedding_batch_size) - embeddings: List[Optional[List[float]]] = [] + embeddings: list[list[float] | None] = [] for batch in batches: embeddings.extend( self.embeddings.embed_documents([chunk.page_content for chunk in batch]) @@ -142,16 +141,16 @@ def __init__(self, config: CohereEmbeddingConfigModel): cohere_api_key=config.cohere_key, model="embed-english-light-v2.0" ) # type: ignore - def check(self) -> Optional[str]: + def check(self) -> str | None: try: self.embeddings.embed_query("test") except Exception as e: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: return cast( - List[Optional[List[float]]], + list[list[float] | None], self.embeddings.embed_documents([document.page_content for document in documents]), ) @@ -166,16 +165,16 @@ def __init__(self, config: FakeEmbeddingConfigModel): super().__init__() self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE) - def check(self) -> Optional[str]: + def check(self) -> str | None: try: self.embeddings.embed_query("test") except Exception as e: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: return cast( - List[Optional[List[float]]], + list[list[float] | None], self.embeddings.embed_documents([document.page_content for document in documents]), ) @@ -202,7 +201,7 @@ def __init__(self, config: OpenAICompatibleEmbeddingConfigModel): disallowed_special=(), ) # type: ignore - def check(self) -> Optional[str]: + def check(self) -> str | None: deployment_mode = os.environ.get("DEPLOYMENT_MODE", "") if ( deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE @@ -216,9 +215,9 @@ def check(self) -> Optional[str]: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: return cast( - List[Optional[List[float]]], + list[list[float] | None], self.embeddings.embed_documents([document.page_content for document in documents]), ) @@ -233,15 +232,14 @@ def __init__(self, config: FromFieldEmbeddingConfigModel): super().__init__() self.config = config - def check(self) -> Optional[str]: + def check(self) -> str | None: return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: - """ - From each chunk, pull the embedding from the field specified in the config. 
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: + """From each chunk, pull the embedding from the field specified in the config. Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem. """ - embeddings: List[Optional[List[float]]] = [] + embeddings: list[list[float] | None] = [] for document in documents: data = document.record.data if self.config.field_name not in data: @@ -283,14 +281,12 @@ def embedding_dimensions(self) -> int: def create_from_config( - embedding_config: Union[ - AzureOpenAIEmbeddingConfigModel, - CohereEmbeddingConfigModel, - FakeEmbeddingConfigModel, - FromFieldEmbeddingConfigModel, - OpenAIEmbeddingConfigModel, - OpenAICompatibleEmbeddingConfigModel, - ], + embedding_config: AzureOpenAIEmbeddingConfigModel + | CohereEmbeddingConfigModel + | FakeEmbeddingConfigModel + | FromFieldEmbeddingConfigModel + | OpenAIEmbeddingConfigModel + | OpenAICompatibleEmbeddingConfigModel, processing_config: ProcessingConfigModel, ) -> Embedder: if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai": @@ -298,5 +294,4 @@ def create_from_config( Embedder, embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size), ) - else: - return cast(Embedder, embedder_map[embedding_config.mode](embedding_config)) + return cast(Embedder, embedder_map[embedding_config.mode](embedding_config)) diff --git a/airbyte_cdk/destinations/vector_db_based/indexer.py b/airbyte_cdk/destinations/vector_db_based/indexer.py index c49f576a..1ce58965 100644 --- a/airbyte_cdk/destinations/vector_db_based/indexer.py +++ b/airbyte_cdk/destinations/vector_db_based/indexer.py @@ -1,18 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import itertools from abc import ABC, abstractmethod -from typing import Any, Generator, Iterable, List, Optional, Tuple, TypeVar +from collections.abc import Generator, Iterable +from typing import Any, TypeVar from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog class Indexer(ABC): - """ - Indexer is an abstract class that defines the interface for indexing documents. + """Indexer is an abstract class that defines the interface for indexing documents. The Writer class uses the Indexer class to internally index documents generated by the document processor. In a destination connector, implement a custom indexer by extending this class and implementing the abstract methods. @@ -23,24 +24,20 @@ def __init__(self, config: Any): pass def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None: - """ - Run before the sync starts. This method should be used to make sure all records in the destination that belong to streams with a destination mode of overwrite are deleted. + """Run before the sync starts. This method should be used to make sure all records in the destination that belong to streams with a destination mode of overwrite are deleted. Each record has a metadata field with the name airbyte_cdk.destinations.vector_db_based.document_processor.METADATA_STREAM_FIELD which can be used to filter documents for deletion. Use the airbyte_cdk.destinations.vector_db_based.utils.create_stream_identifier method to create the stream identifier based on the stream definition to use for filtering. 
""" pass - def post_sync(self) -> List[AirbyteMessage]: - """ - Run after the sync finishes. This method should be used to perform any cleanup operations and can return a list of AirbyteMessages to be logged. - """ + def post_sync(self) -> list[AirbyteMessage]: + """Run after the sync finishes. This method should be used to perform any cleanup operations and can return a list of AirbyteMessages to be logged.""" return [] @abstractmethod - def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> None: - """ - Index a list of document chunks. + def index(self, document_chunks: list[Chunk], namespace: str, stream: str) -> None: + """Index a list of document chunks. This method should be used to index the documents in the destination. If page_content is None, the document should be indexed without the raw text. All chunks belong to the stream and namespace specified in the parameters. @@ -48,9 +45,8 @@ def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> No pass @abstractmethod - def delete(self, delete_ids: List[str], namespace: str, stream: str) -> None: - """ - Delete document chunks belonging to certain record ids. + def delete(self, delete_ids: list[str], namespace: str, stream: str) -> None: + """Delete document chunks belonging to certain record ids. This method should be used to delete documents from the destination. The delete_ids parameter contains a list of record ids - all chunks with a record id in this list should be deleted from the destination. @@ -59,17 +55,15 @@ def delete(self, delete_ids: List[str], namespace: str, stream: str) -> None: pass @abstractmethod - def check(self) -> Optional[str]: - """ - Check if the indexer is configured correctly. This method should be used to check if the indexer is configured correctly and return an error message if it is not. - """ + def check(self) -> str | None: + """Check if the indexer is configured correctly. This method should be used to check if the indexer is configured correctly and return an error message if it is not.""" pass T = TypeVar("T") -def chunks(iterable: Iterable[T], batch_size: int) -> Generator[Tuple[T, ...], None, None]: +def chunks(iterable: Iterable[T], batch_size: int) -> Generator[tuple[T, ...], None, None]: """A helper function to break an iterable into chunks of size batch_size.""" it = iter(iterable) chunk = tuple(itertools.islice(it, batch_size)) diff --git a/airbyte_cdk/destinations/vector_db_based/test_utils.py b/airbyte_cdk/destinations/vector_db_based/test_utils.py index a2f3d3d8..80cc17d2 100644 --- a/airbyte_cdk/destinations/vector_db_based/test_utils.py +++ b/airbyte_cdk/destinations/vector_db_based/test_utils.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import unittest -from typing import Any, Dict +from typing import Any from airbyte_cdk.models import ( AirbyteMessage, @@ -20,8 +21,7 @@ class BaseIntegrationTest(unittest.TestCase): - """ - BaseIntegrationTest is a base class for integration tests for vector db destinations. + """BaseIntegrationTest is a base class for integration tests for vector db destinations. It provides helper methods to create Airbyte catalogs, records and state messages. 
""" @@ -47,7 +47,7 @@ def _get_configured_catalog( return ConfiguredAirbyteCatalog(streams=[overwrite_stream]) - def _state(self, data: Dict[str, Any]) -> AirbyteMessage: + def _state(self, data: dict[str, Any]) -> AirbyteMessage: return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=data)) def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage: @@ -59,5 +59,5 @@ def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage ) def setUp(self) -> None: - with open("secrets/config.json", "r") as f: + with open("secrets/config.json") as f: self.config = json.loads(f.read()) diff --git a/airbyte_cdk/destinations/vector_db_based/utils.py b/airbyte_cdk/destinations/vector_db_based/utils.py index dbb1f471..288b2df6 100644 --- a/airbyte_cdk/destinations/vector_db_based/utils.py +++ b/airbyte_cdk/destinations/vector_db_based/utils.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import itertools import traceback -from typing import Any, Iterable, Iterator, Tuple, Union +from collections.abc import Iterable, Iterator +from typing import Any from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStream @@ -17,7 +19,7 @@ def format_exception(exception: Exception) -> str: ) -def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[Any, ...]]: +def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[tuple[Any, ...]]: """A helper function to break an iterable into chunks of size batch_size.""" it = iter(iterable) chunk = tuple(itertools.islice(it, batch_size)) @@ -26,10 +28,7 @@ def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[An chunk = tuple(itertools.islice(it, batch_size)) -def create_stream_identifier(stream: Union[AirbyteStream, AirbyteRecordMessage]) -> str: +def create_stream_identifier(stream: AirbyteStream | AirbyteRecordMessage) -> str: if isinstance(stream, AirbyteStream): return str(stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}") - else: - return str( - stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}" - ) + return str(stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}") diff --git a/airbyte_cdk/destinations/vector_db_based/writer.py b/airbyte_cdk/destinations/vector_db_based/writer.py index 268e49ef..8a9a0da1 100644 --- a/airbyte_cdk/destinations/vector_db_based/writer.py +++ b/airbyte_cdk/destinations/vector_db_based/writer.py @@ -1,10 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - +from __future__ import annotations from collections import defaultdict -from typing import Dict, Iterable, List, Tuple +from collections.abc import Iterable from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor @@ -14,8 +14,7 @@ class Writer: - """ - The Writer class is orchestrating the document processor, the embedder and the indexer: + """The Writer class orchestrates the document processor, the embedder and the indexer: * Incoming records are passed through the document processor to generate chunks * Once the configured batch size is reached, the chunks are passed to the embedder to generate embeddings * The embedder embeds the chunks @@ -42,14 +41,12 @@ def __init__( self._init_batch() def _init_batch(self) -> None: - self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list) - self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list) + self.chunks: dict[tuple[str, str], list[Chunk]] = defaultdict(list) + self.ids_to_delete: dict[tuple[str, str], list[str]] = defaultdict(list) self.number_of_chunks = 0 def _convert_to_document(self, chunk: Chunk) -> Document: - """ - Convert a chunk to a document for the embedder. - """ + """Convert a chunk to a document for the embedder.""" if chunk.page_content is None: raise ValueError("Cannot embed a chunk without page content") return Document(page_content=chunk.page_content, record=chunk.record) @@ -83,9 +80,9 @@ def write( yield message elif message.type == Type.RECORD: record_chunks, record_id_to_delete = self.processor.process(message.record) - self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks) + self.chunks[message.record.namespace, message.record.stream].extend(record_chunks) if record_id_to_delete is not None: - self.ids_to_delete[(message.record.namespace, message.record.stream)].append( + self.ids_to_delete[message.record.namespace, message.record.stream].append( record_id_to_delete ) self.number_of_chunks += len(record_chunks) diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py index 5a979a94..163d8fce 100644 --- a/airbyte_cdk/entrypoint.py +++ b/airbyte_cdk/entrypoint.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
# +from __future__ import annotations import argparse import importlib @@ -11,11 +12,15 @@ import sys import tempfile from collections import defaultdict +from collections.abc import Iterable, Mapping from functools import wraps -from typing import Any, DefaultDict, Iterable, List, Mapping, Optional +from typing import Any from urllib.parse import urlparse import requests +from orjson import orjson +from requests import PreparedRequest, Response, Session + from airbyte_cdk.connector import TConfig from airbyte_cdk.exception_handler import init_uncaught_exception_handler from airbyte_cdk.logger import init_logger @@ -38,8 +43,7 @@ from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from orjson import orjson -from requests import PreparedRequest, Response, Session + logger = init_logger("airbyte") @@ -47,7 +51,7 @@ CLOUD_DEPLOYMENT_MODE = "cloud" -class AirbyteEntrypoint(object): +class AirbyteEntrypoint: def __init__(self, source: Source): init_uncaught_exception_handler(logger) @@ -59,7 +63,7 @@ def __init__(self, source: Source): self.logger = logging.getLogger(f"airbyte.{getattr(source, 'name', '')}") @staticmethod - def parse_args(args: List[str]) -> argparse.Namespace: + def parse_args(args: list[str]) -> argparse.Namespace: # set up parent parsers parent_parser = argparse.ArgumentParser(add_help=False) parent_parser.add_argument( @@ -233,7 +237,7 @@ def read( self.validate_connection(source_spec, config) # The Airbyte protocol dictates that counts be expressed as float/double to better protect against integer overflows - stream_message_counter: DefaultDict[HashableStreamDescriptor, float] = defaultdict(float) + stream_message_counter: defaultdict[HashableStreamDescriptor, float] = defaultdict(float) for message in self.source.read(self.logger, config, catalog, state): yield self.handle_record_counts(message, stream_message_counter) for message in self._emit_queued_messages(self.source): @@ -241,7 +245,7 @@ def read( @staticmethod def handle_record_counts( - message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float] + message: AirbyteMessage, stream_message_count: defaultdict[HashableStreamDescriptor, float] ) -> AirbyteMessage: match message.type: case Type.RECORD: @@ -282,21 +286,21 @@ def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str: return orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string @classmethod - def extract_state(cls, args: List[str]) -> Optional[Any]: + def extract_state(cls, args: list[str]) -> Any | None: parsed_args = cls.parse_args(args) if hasattr(parsed_args, "state"): return parsed_args.state return None @classmethod - def extract_catalog(cls, args: List[str]) -> Optional[Any]: + def extract_catalog(cls, args: list[str]) -> Any | None: parsed_args = cls.parse_args(args) if hasattr(parsed_args, "catalog"): return parsed_args.catalog return None @classmethod - def extract_config(cls, args: List[str]) -> Optional[Any]: + def extract_config(cls, args: list[str]) -> Any | None: parsed_args = cls.parse_args(args) if hasattr(parsed_args, "config"): return parsed_args.config @@ -308,7 +312,7 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]: return -def launch(source: Source, args: List[str]) -> None: +def launch(source: Source, 
args: list[str]) -> None: source_entrypoint = AirbyteEntrypoint(source) parsed_args = source_entrypoint.parse_args(args) # temporarily removes the PrintBuffer because we're seeing weird print behavior for concurrent syncs @@ -321,9 +325,7 @@ def launch(source: Source, args: List[str]) -> None: def _init_internal_request_filter() -> None: - """ - Wraps the Python requests library to prevent sending requests to internal URL endpoints. - """ + """Wraps the Python requests library to prevent sending requests to internal URL endpoints.""" wrapped_fn = Session.send @wraps(wrapped_fn) @@ -361,9 +363,7 @@ def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Respons def _is_private_url(hostname: str, port: int) -> bool: - """ - Helper method that checks if any of the IP addresses associated with a hostname belong to a private network. - """ + """Helper method that checks if any of the IP addresses associated with a hostname belong to a private network.""" address_info_entries = socket.getaddrinfo(hostname, port) for entry in address_info_entries: # getaddrinfo() returns entries in the form of a 5-tuple where the IP is stored as the sockaddr. For IPv4 this diff --git a/airbyte_cdk/exception_handler.py b/airbyte_cdk/exception_handler.py index 84aa39ba..7525bb85 100644 --- a/airbyte_cdk/exception_handler.py +++ b/airbyte_cdk/exception_handler.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import sys +from collections.abc import Mapping from types import TracebackType -from typing import Any, List, Mapping, Optional +from typing import Any from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -20,15 +22,14 @@ def assemble_uncaught_exception( def init_uncaught_exception_handler(logger: logging.Logger) -> None: - """ - Handles uncaught exceptions by emitting an AirbyteTraceMessage and making sure they are not + """Handles uncaught exceptions by emitting an AirbyteTraceMessage and making sure they are not printed to the console without having secrets removed. """ def hook_fn( exception_type: type[BaseException], exception_value: BaseException, - traceback_: Optional[TracebackType], + traceback_: TracebackType | None, ) -> Any: # For developer ergonomics, we want to see the stack trace in the logs when we do a ctrl-c if issubclass(exception_type, KeyboardInterrupt): @@ -45,7 +46,7 @@ def hook_fn( sys.excepthook = hook_fn -def generate_failed_streams_error_message(stream_failures: Mapping[str, List[Exception]]) -> str: +def generate_failed_streams_error_message(stream_failures: Mapping[str, list[Exception]]) -> str: failures = "\n".join( [ f"{stream}: {filter_secrets(exception.__repr__())}" diff --git a/airbyte_cdk/logger.py b/airbyte_cdk/logger.py index 055d80e8..67c8588e 100644 --- a/airbyte_cdk/logger.py +++ b/airbyte_cdk/logger.py @@ -1,11 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json import logging import logging.config -from typing import Any, Callable, Mapping, Optional, Tuple +from collections.abc import Callable, Mapping +from typing import Any + +from orjson import orjson from airbyte_cdk.models import ( AirbyteLogMessage, @@ -15,7 +19,7 @@ Type, ) from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets -from orjson import orjson + LOGGING_CONFIG = { "version": 1, @@ -36,7 +40,7 @@ } -def init_logger(name: Optional[str] = None) -> logging.Logger: +def init_logger(name: str | None = None) -> logging.Logger: """Initial set up of logger""" logger = logging.getLogger(name) logger.setLevel(logging.INFO) @@ -45,9 +49,7 @@ def init_logger(name: Optional[str] = None) -> logging.Logger: def lazy_log(logger: logging.Logger, level: int, lazy_log_provider: Callable[[], str]) -> None: - """ - This method ensure that the processing of the log message is only done if the logger is enabled for the log level. - """ + """This method ensures that the processing of the log message is only done if the logger is enabled for the log level.""" if logger.isEnabledFor(level): logger.log(level, lazy_log_provider()) @@ -71,18 +73,16 @@ def format(self, record: logging.LogRecord) -> str: extras = self.extract_extra_args_from_record(record) debug_dict = {"type": "DEBUG", "message": record.getMessage(), "data": extras} return filter_secrets(json.dumps(debug_dict)) - else: - message = super().format(record) - message = filter_secrets(message) - log_message = AirbyteMessage( - type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message) - ) - return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string + message = super().format(record) + message = filter_secrets(message) + log_message = AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message) + ) + return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string @staticmethod def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]: - """ - The python logger conflates default args with extra args. We use an empty log record and set operations + """The Python logger conflates default args with extra args. We use an empty log record and set operations to isolate fields passed to the log record via extra by the developer. """ default_attrs = logging.LogRecord("", 0, "", 0, None, None, None).__dict__.keys() @@ -90,7 +90,7 @@ def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, An return {k: str(getattr(record, k)) for k in extra_keys if hasattr(record, k)} -def log_by_prefix(msg: str, default_level: str) -> Tuple[int, str]: +def log_by_prefix(msg: str, default_level: str) -> tuple[int, str]: """Custom method, which takes log level from first word of message""" valid_log_types = ["FATAL", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"] split_line = msg.split() diff --git a/airbyte_cdk/models/airbyte_protocol.py b/airbyte_cdk/models/airbyte_protocol.py index 6be79948..ca390e5c 100644 --- a/airbyte_cdk/models/airbyte_protocol.py +++ b/airbyte_cdk/models/airbyte_protocol.py @@ -1,21 +1,25 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Annotated, Any, Dict, List, Mapping, Optional, Union +from typing import Annotated, Any -from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage -from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*' from serpyco_rs.metadata import Alias +from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*' + +from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage + + # ruff: noqa: F405 # ignore fuzzy import issues with 'import *' @dataclass class AirbyteStateBlob: - """ - A dataclass that dynamically sets attributes based on provided keyword arguments and positional arguments. + """A dataclass that dynamically sets attributes based on provided keyword arguments and positional arguments. Used to "mimic" pydantic Basemodel with ConfigDict(extra='allow') option. The `AirbyteStateBlob` class allows for flexible instantiation by accepting any number of keyword arguments @@ -55,35 +59,35 @@ def __eq__(self, other: object) -> bool: @dataclass class AirbyteStreamState: stream_descriptor: StreamDescriptor # type: ignore [name-defined] - stream_state: Optional[AirbyteStateBlob] = None + stream_state: AirbyteStateBlob | None = None @dataclass class AirbyteGlobalState: - stream_states: List[AirbyteStreamState] - shared_state: Optional[AirbyteStateBlob] = None + stream_states: list[AirbyteStreamState] + shared_state: AirbyteStateBlob | None = None @dataclass class AirbyteStateMessage: - type: Optional[AirbyteStateType] = None # type: ignore [name-defined] - stream: Optional[AirbyteStreamState] = None + type: AirbyteStateType | None = None # type: ignore [name-defined] + stream: AirbyteStreamState | None = None global_: Annotated[AirbyteGlobalState | None, Alias("global")] = ( None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization ) - data: Optional[Dict[str, Any]] = None - sourceStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined] - destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined] + data: dict[str, Any] | None = None + sourceStats: AirbyteStateStats | None = None # type: ignore [name-defined] + destinationStats: AirbyteStateStats | None = None # type: ignore [name-defined] @dataclass class AirbyteMessage: type: Type # type: ignore [name-defined] - log: Optional[AirbyteLogMessage] = None # type: ignore [name-defined] - spec: Optional[ConnectorSpecification] = None # type: ignore [name-defined] - connectionStatus: Optional[AirbyteConnectionStatus] = None # type: ignore [name-defined] - catalog: Optional[AirbyteCatalog] = None # type: ignore [name-defined] - record: Optional[Union[AirbyteFileTransferRecordMessage, AirbyteRecordMessage]] = None # type: ignore [name-defined] - state: Optional[AirbyteStateMessage] = None - trace: Optional[AirbyteTraceMessage] = None # type: ignore [name-defined] - control: Optional[AirbyteControlMessage] = None # type: ignore [name-defined] + log: AirbyteLogMessage | None = None # type: ignore [name-defined] + spec: ConnectorSpecification | None = None # type: ignore [name-defined] + connectionStatus: AirbyteConnectionStatus | None = None # type: ignore [name-defined] + catalog: AirbyteCatalog | None = None # type: ignore [name-defined] + record: AirbyteFileTransferRecordMessage | AirbyteRecordMessage | None = None # type: ignore [name-defined] + state: 
AirbyteStateMessage | None = None + trace: AirbyteTraceMessage | None = None # type: ignore [name-defined] + control: AirbyteControlMessage | None = None # type: ignore [name-defined] diff --git a/airbyte_cdk/models/airbyte_protocol_serializers.py b/airbyte_cdk/models/airbyte_protocol_serializers.py index 129556ac..8ae8f2cc 100644 --- a/airbyte_cdk/models/airbyte_protocol_serializers.py +++ b/airbyte_cdk/models/airbyte_protocol_serializers.py @@ -1,5 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Any, Dict +from __future__ import annotations + +from typing import Any from serpyco_rs import CustomType, Serializer @@ -14,19 +16,19 @@ ) -class AirbyteStateBlobType(CustomType[AirbyteStateBlob, Dict[str, Any]]): - def serialize(self, value: AirbyteStateBlob) -> Dict[str, Any]: +class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]): + def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]: # cant use orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete" return {k: v for k, v in value.__dict__.items()} - def deserialize(self, value: Dict[str, Any]) -> AirbyteStateBlob: + def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob: return AirbyteStateBlob(value) - def get_json_schema(self) -> Dict[str, Any]: + def get_json_schema(self) -> dict[str, Any]: return {"type": "object"} -def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, Dict[str, Any]] | None: +def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None: return AirbyteStateBlobType() if t is AirbyteStateBlob else None diff --git a/airbyte_cdk/models/file_transfer_record_message.py b/airbyte_cdk/models/file_transfer_record_message.py index dcc1b7a9..5a30bd69 100644 --- a/airbyte_cdk/models/file_transfer_record_message.py +++ b/airbyte_cdk/models/file_transfer_record_message.py @@ -1,13 +1,14 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any @dataclass class AirbyteFileTransferRecordMessage: stream: str - file: Dict[str, Any] + file: dict[str, Any] emitted_at: int - namespace: Optional[str] = None - data: Optional[Dict[str, Any]] = None + namespace: str | None = None + data: dict[str, Any] | None = None diff --git a/airbyte_cdk/models/well_known_types.py b/airbyte_cdk/models/well_known_types.py index 7b1ea492..a695af3a 100644 --- a/airbyte_cdk/models/well_known_types.py +++ b/airbyte_cdk/models/well_known_types.py @@ -1,5 +1,6 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from airbyte_protocol_dataclasses.models.well_known_types import * # noqa: F403 # Allow '*' diff --git a/airbyte_cdk/sources/abstract_source.py b/airbyte_cdk/sources/abstract_source.py index 34ba816b..e24a6967 100644 --- a/airbyte_cdk/sources/abstract_source.py +++ b/airbyte_cdk/sources/abstract_source.py @@ -1,20 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator, Mapping, MutableMapping from typing import ( Any, - Dict, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - Optional, - Tuple, - Union, ) from airbyte_cdk.exception_handler import generate_failed_streams_error_message @@ -46,21 +39,20 @@ ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException + _default_message_repository = InMemoryMessageRepository() class AbstractSource(Source, ABC): - """ - Abstract base class for an Airbyte Source. Consumers should implement any abstract methods + """Abstract base class for an Airbyte Source. Consumers should implement any abstract methods in this class to create an Airbyte Specification compliant Source. """ @abstractmethod def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: - """ - :param logger: source logger + ) -> tuple[bool, Any | None]: + """:param logger: source logger :param config: The user-provided configuration as specified by the source's spec. This usually contains information required to check connection e.g. tokens, secrets and keys etc. :return: A tuple of (boolean, error). If boolean is true, then the connection check is successful @@ -71,15 +63,14 @@ def check_connection( """ @abstractmethod - def streams(self, config: Mapping[str, Any]) -> List[Stream]: - """ - :param config: The user-provided configuration as specified by the source's spec. + def streams(self, config: Mapping[str, Any]) -> list[Stream]: + """:param config: The user-provided configuration as specified by the source's spec. Any stream construction related operation should happen here. :return: A list of the streams in this source connector. """ # Stream name to instance map for applying output object transformation - _stream_to_instance_map: Dict[str, Stream] = {} + _stream_to_instance_map: dict[str, Stream] = {} _slice_logger: SliceLogger = DebugSliceLogger() def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: @@ -103,7 +94,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: """Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/.""" logger.info(f"Starting syncing {self.name}") @@ -212,7 +203,7 @@ def read( @staticmethod def _serialize_exception( - stream_descriptor: StreamDescriptor, e: Exception, stream_instance: Optional[Stream] = None + stream_descriptor: StreamDescriptor, e: Exception, stream_instance: Stream | None = None ) -> AirbyteTracedException: display_message = stream_instance.get_error_display_message(e) if stream_instance else None if display_message: @@ -294,11 +285,9 @@ def _emit_queued_messages(self) -> Iterable[AirbyteMessage]: return def _get_message( - self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream + self, record_data_or_message: StreamData | AirbyteMessage, stream: Stream ) -> AirbyteMessage: - """ - Converts the input to an AirbyteMessage if it is a StreamData. Returns the input as is if it is already an AirbyteMessage - """ + """Converts the input to an AirbyteMessage if it is a StreamData. 
Returns the input as is if it is already an AirbyteMessage""" match record_data_or_message: case AirbyteMessage(): return record_data_or_message @@ -311,13 +300,12 @@ def _get_message( ) @property - def message_repository(self) -> Union[None, MessageRepository]: + def message_repository(self) -> None | MessageRepository: return _default_message_repository @property def stop_sync_on_stream_failure(self) -> bool: - """ - WARNING: This function is in-development which means it is subject to change. Use at your own risk. + """WARNING: This function is in-development which means it is subject to change. Use at your own risk. By default, when a source encounters an exception while syncing a stream, it will emit an error trace message and then continue syncing the next stream. This can be overwritten on a per-source basis so that the source will stop the sync diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 1f4a1b81..3940a077 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import logging -from typing import Dict, Iterable, List, Optional, Set +from collections.abc import Iterable from airbyte_cdk.exception_handler import generate_failed_streams_error_message from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus, FailureType, StreamDescriptor @@ -30,7 +32,7 @@ class ConcurrentReadProcessor: def __init__( self, - stream_instances_to_read_from: List[AbstractStream], + stream_instances_to_read_from: list[AbstractStream], partition_enqueuer: PartitionEnqueuer, thread_pool_manager: ThreadPoolManager, logger: logging.Logger, @@ -38,8 +40,7 @@ def __init__( message_repository: MessageRepository, partition_reader: PartitionReader, ): - """ - This class is responsible for handling items from a concurrent stream read process. + """This class is responsible for handling items from a concurrent stream read process. 
:param stream_instances_to_read_from: List of streams to read from :param partition_enqueuer: PartitionEnqueuer instance :param thread_pool_manager: ThreadPoolManager instance @@ -50,26 +51,25 @@ def __init__( """ self._stream_name_to_instance = {s.name: s for s in stream_instances_to_read_from} self._record_counter = {} - self._streams_to_running_partitions: Dict[str, Set[Partition]] = {} + self._streams_to_running_partitions: dict[str, set[Partition]] = {} for stream in stream_instances_to_read_from: self._streams_to_running_partitions[stream.name] = set() self._record_counter[stream.name] = 0 self._thread_pool_manager = thread_pool_manager self._partition_enqueuer = partition_enqueuer self._stream_instances_to_start_partition_generation = stream_instances_to_read_from - self._streams_currently_generating_partitions: List[str] = [] + self._streams_currently_generating_partitions: list[str] = [] self._logger = logger self._slice_logger = slice_logger self._message_repository = message_repository self._partition_reader = partition_reader - self._streams_done: Set[str] = set() - self._exceptions_per_stream_name: dict[str, List[Exception]] = {} + self._streams_done: set[str] = set() + self._exceptions_per_stream_name: dict[str, list[Exception]] = {} def on_partition_generation_completed( self, sentinel: PartitionGenerationCompletedSentinel ) -> Iterable[AirbyteMessage]: - """ - This method is called when a partition generation is completed. + """This method is called when a partition generation is completed. 1. Remove the stream from the list of streams currently generating partitions 2. If the stream is done, mark it as such and return a stream status message 3. If there are more streams to read from, start the next partition generator @@ -87,8 +87,7 @@ def on_partition_generation_completed( yield self.start_next_partition_generator() # type:ignore # None may be yielded def on_partition(self, partition: Partition) -> None: - """ - This method is called when a partition is generated. + """This method is called when a partition is generated. 1. Add the partition to the set of partitions for the stream 2. Log the slice if necessary 3. Submit the partition to the thread pool manager @@ -104,8 +103,7 @@ def on_partition(self, partition: Partition) -> None: def on_partition_complete_sentinel( self, sentinel: PartitionCompleteSentinel ) -> Iterable[AirbyteMessage]: - """ - This method is called when a partition is completed. + """This method is called when a partition is completed. 1. Close the partition 2. If the stream is done, mark it as such and return a stream status message 3. Emit messages that were added to the message repository @@ -133,8 +131,7 @@ def on_partition_complete_sentinel( yield from self._message_repository.consume_queue() def on_record(self, record: Record) -> Iterable[AirbyteMessage]: - """ - This method is called when a record is read from a partition. + """This method is called when a record is read from a partition. 1. Convert the record to an AirbyteMessage 2. If this is the first record for the stream, mark the stream as RUNNING 3. Increment the record counter for the stream @@ -164,8 +161,7 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]: yield from self._message_repository.consume_queue() def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMessage]: - """ - This method is called when an exception is raised. + """This method is called when an exception is raised. 1. Stop all running streams 2. 
Raise the exception """ self._flag_exception(exception.stream_name, exception.exception) self._logger.exception( @@ -185,9 +181,8 @@ def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMess def _flag_exception(self, stream_name: str, exception: Exception) -> None: self._exceptions_per_stream_name.setdefault(stream_name, []).append(exception) - def start_next_partition_generator(self) -> Optional[AirbyteMessage]: - """ - Start the next partition generator. + def start_next_partition_generator(self) -> AirbyteMessage | None: + """Start the next partition generator. 1. Pop the next stream to read from 2. Submit the partition generator to the thread pool manager 3. Add the stream to the list of streams currently generating partitions @@ -203,12 +198,10 @@ def start_next_partition_generator(self) -> Optional[AirbyteMessage]: stream.as_airbyte_stream(), AirbyteStreamStatus.STARTED, ) - else: - return None + return None def is_done(self) -> bool: - """ - This method is called to check if the sync is done. + """This method is called to check if the sync is done. The sync is done when: 1. There are no more streams generating partitions 2. There are no more streams to read from diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index e5540799..45a206d2 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import concurrent import logging +from collections.abc import Iterable, Iterator from queue import Queue -from typing import Iterable, Iterator, List from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor @@ -27,8 +29,7 @@ class ConcurrentSource: - """ - A Source that reads data from multiple AbstractStreams concurrently. + """A Source that reads data from multiple AbstractStreams concurrently. It does so by submitting partition generation, and partition read tasks to a thread pool. The tasks asynchronously add their output to a shared queue. The read is done when all partitions for all streams were generated and read.
@@ -44,7 +45,7 @@ def create( slice_logger: SliceLogger, message_repository: MessageRepository, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, - ) -> "ConcurrentSource": + ) -> ConcurrentSource: is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1 too_many_generator = ( not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers @@ -76,8 +77,7 @@ def __init__( initial_number_partitions_to_generate: int = 1, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, ) -> None: - """ - :param threadpool: The threadpool to submit tasks to + """:param threadpool: The threadpool to submit tasks to :param logger: The logger to log to :param slice_logger: The slice logger used to create messages on new slices :param message_repository: The repository to emit messages to @@ -93,7 +93,7 @@ def __init__( def read( self, - streams: List[AbstractStream], + streams: list[AbstractStream], ) -> Iterator[AirbyteMessage]: self._logger.info("Starting syncing") diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py index c150dc95..fd184423 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import ABC +from collections.abc import Callable, Iterator, Mapping, MutableMapping from datetime import timedelta -from typing import Any, Callable, Iterator, List, Mapping, MutableMapping, Optional, Tuple +from typing import Any from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog from airbyte_cdk.sources import AbstractSource @@ -27,13 +29,13 @@ AbstractStreamStateConverter, ) + DEFAULT_LOOKBACK_SECONDS = 0 class ConcurrentSourceAdapter(AbstractSource, ABC): def __init__(self, concurrent_source: ConcurrentSource, **kwargs: Any) -> None: - """ - ConcurrentSourceAdapter is a Source that wraps a concurrent source and exposes it as a regular source. + """ConcurrentSourceAdapter is a Source that wraps a concurrent source and exposes it as a regular source. The source's streams are still defined through the streams() method. Streams wrapped in a StreamFacade will be processed concurrently. @@ -47,7 +49,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: abstract_streams = self._select_abstract_streams(config, catalog) concurrent_stream_names = {stream.name for stream in abstract_streams} @@ -65,13 +67,11 @@ def read( def _select_abstract_streams( self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog - ) -> List[AbstractStream]: - """ - Selects streams that can be processed concurrently and returns their abstract representations. 
- """ + ) -> list[AbstractStream]: + """Selects streams that can be processed concurrently and returns their abstract representations.""" all_streams = self.streams(config) stream_name_to_instance: Mapping[str, Stream] = {s.name: s for s in all_streams} - abstract_streams: List[AbstractStream] = [] + abstract_streams: list[AbstractStream] = [] for configured_stream in configured_catalog.streams: stream_instance = stream_name_to_instance.get(configured_stream.stream.name) if not stream_instance: @@ -86,10 +86,9 @@ def convert_to_concurrent_stream( logger: logging.Logger, stream: Stream, state_manager: ConnectorStateManager, - cursor: Optional[Cursor] = None, + cursor: Cursor | None = None, ) -> Stream: - """ - Prepares a stream for concurrent processing by initializing or assigning a cursor, + """Prepares a stream for concurrent processing by initializing or assigning a cursor, managing the stream's state, and returning an updated Stream instance. """ state: MutableMapping[str, Any] = {} @@ -113,12 +112,12 @@ def initialize_cursor( stream: Stream, state_manager: ConnectorStateManager, converter: AbstractStreamStateConverter, - slice_boundary_fields: Optional[Tuple[str, str]], - start: Optional[CursorValueType], + slice_boundary_fields: tuple[str, str] | None, + start: CursorValueType | None, end_provider: Callable[[], CursorValueType], - lookback_window: Optional[GapType] = None, - slice_range: Optional[GapType] = None, - ) -> Optional[ConcurrentCursor]: + lookback_window: GapType | None = None, + slice_range: GapType | None = None, + ) -> ConcurrentCursor | None: lookback_window = lookback_window or timedelta(seconds=DEFAULT_LOOKBACK_SECONDS) cursor_field_name = stream.cursor_field diff --git a/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py b/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py index b6643042..98265dcf 100644 --- a/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py +++ b/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py @@ -1,24 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any +from __future__ import annotations from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream class PartitionGenerationCompletedSentinel: - """ - A sentinel object indicating all partitions for a stream were produced. + """A sentinel object indicating all partitions for a stream were produced. Includes a pointer to the stream that was processed. """ def __init__(self, stream: AbstractStream): - """ - :param stream: The stream that was processed - """ + """:param stream: The stream that was processed""" self.stream = stream - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, PartitionGenerationCompletedSentinel): return self.stream == other.stream return False diff --git a/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py b/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py index c865bef5..5667a1e8 100644 --- a/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py +++ b/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py @@ -1,6 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
- -from typing import Any +from __future__ import annotations class StreamThreadException(Exception): @@ -19,7 +18,7 @@ def exception(self) -> Exception: def __str__(self) -> str: return f"Exception while syncing stream {self._stream_name}: {self._exception}" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, StreamThreadException): return self._exception == other._exception and self._stream_name == other._stream_name return False diff --git a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py index 59f8a1f0..1c59325d 100644 --- a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py +++ b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py @@ -1,16 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import logging import threading +from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, List, Optional +from typing import Any class ThreadPoolManager: - """ - Wrapper to abstract away the threadpool and the logic to wait for pending tasks to be completed. - """ + """Wrapper to abstract away the threadpool and the logic to wait for pending tasks to be completed.""" DEFAULT_MAX_QUEUE_SIZE = 10_000 @@ -20,17 +21,16 @@ def __init__( self, logger: logging.Logger, max_concurrent_tasks: int = DEFAULT_MAX_QUEUE_SIZE, ): - """ - :param threadpool: The threadpool to use + """:param threadpool: The threadpool to use :param logger: The logger to use :param max_concurrent_tasks: The maximum number of tasks that can be pending at the same time """ self._threadpool = threadpool self._logger = logger self._max_concurrent_tasks = max_concurrent_tasks - self._futures: List[Future[Any]] = [] + self._futures: list[Future[Any]] = [] self._lock = threading.Lock() - self._most_recently_seen_exception: Optional[Exception] = None + self._most_recently_seen_exception: Exception | None = None self._logging_threshold = max_concurrent_tasks * 2 @@ -45,9 +45,8 @@ def prune_to_validate_has_reached_futures_limit(self) -> bool: def submit(self, function: Callable[..., Any], *args: Any) -> None: self._futures.append(self._threadpool.submit(function, *args)) - def _prune_futures(self, futures: List[Future[Any]]) -> None: - """ - Take a list in input and remove the futures that are completed. If a future has an exception, it'll raise and kill the stream + def _prune_futures(self, futures: list[Future[Any]]) -> None: + """Takes a list as input and removes the futures that are completed. If a future has an exception, it'll raise and kill the stream operation. We are using a lock here as without it, the algorithm would not be thread safe @@ -82,8 +81,7 @@ def is_done(self) -> bool: return all([f.done() for f in self._futures]) def check_for_errors_and_shutdown(self) -> None: - """ - Check if any of the futures have an exception, and raise it if so. If all futures are done, shutdown the threadpool. + """Check if any of the futures have an exception, and raise it if so. If all futures are done, shut down the threadpool. If the futures are not done, raise an exception. :return: """ diff --git a/airbyte_cdk/sources/config.py b/airbyte_cdk/sources/config.py index 8ea2b640..209ed17f 100644 --- a/airbyte_cdk/sources/config.py +++ b/airbyte_cdk/sources/config.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
# +from __future__ import annotations -from typing import Any, Dict +from typing import Any -from airbyte_cdk.sources.utils.schema_helpers import expand_refs, rename_key from pydantic.v1 import BaseModel +from airbyte_cdk.sources.utils.schema_helpers import expand_refs, rename_key + class BaseConfig(BaseModel): """Base class for connector spec, adds the following behaviour: @@ -17,7 +19,7 @@ class BaseConfig(BaseModel): """ @classmethod - def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: """We're overriding the schema classmethod to enable some post-processing""" schema = super().schema(*args, **kwargs) rename_key(schema, old_key="anyOf", new_key="oneOf") # UI supports only oneOf diff --git a/airbyte_cdk/sources/connector_state_manager.py b/airbyte_cdk/sources/connector_state_manager.py index 56b58127..c700f8bb 100644 --- a/airbyte_cdk/sources/connector_state_manager.py +++ b/airbyte_cdk/sources/connector_state_manager.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import copy +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass -from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any from airbyte_cdk.models import ( AirbyteMessage, @@ -19,22 +21,20 @@ @dataclass(frozen=True) class HashableStreamDescriptor: - """ - Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and + """Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and freezes its fields so that it be used as a hash key. This is only marked public because we use it outside for unit tests. """ name: str - namespace: Optional[str] = None + namespace: str | None = None class ConnectorStateManager: - """ - ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common + """ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common interface. It also provides methods to extract and update state """ - def __init__(self, state: Optional[List[AirbyteStateMessage]] = None): + def __init__(self, state: list[AirbyteStateMessage] | None = None): shared_state, per_stream_states = self._extract_from_state_message(state) # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are @@ -49,11 +49,8 @@ def __init__(self, state: Optional[List[AirbyteStateMessage]] = None): ) self.per_stream_states = per_stream_states - def get_stream_state( - self, stream_name: str, namespace: Optional[str] - ) -> MutableMapping[str, Any]: - """ - Retrieves the state of a given stream based on its descriptor (name + namespace). + def get_stream_state(self, stream_name: str, namespace: str | None) -> MutableMapping[str, Any]: + """Retrieves the state of a given stream based on its descriptor (name + namespace). 
:param stream_name: Name of the stream being fetched :param namespace: Namespace of the stream being fetched :return: The per-stream state for a stream @@ -66,10 +63,9 @@ def get_stream_state( return {} def update_state_for_stream( - self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any] + self, stream_name: str, namespace: str | None, value: Mapping[str, Any] ) -> None: - """ - Overwrites the state blob of a specific stream based on the provided stream name and optional namespace + """Overwrites the state blob of a specific stream based on the provided stream name and optional namespace :param stream_name: The name of the stream whose state is being updated :param namespace: The namespace of the stream if it exists :param value: A stream state mapping that is being updated for a stream @@ -77,9 +73,8 @@ def update_state_for_stream( stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace) self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value) - def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage: - """ - Generates an AirbyteMessage using the current per-stream state of a specified stream + def create_state_message(self, stream_name: str, namespace: str | None) -> AirbyteMessage: + """Generates an AirbyteMessage using the current per-stream state of a specified stream :param stream_name: The name of the stream for the message that is being created :param namespace: The namespace of the stream for the message that is being created :return: The Airbyte state message to be emitted by the connector during a sync @@ -101,13 +96,12 @@ def create_state_message(self, stream_name: str, namespace: Optional[str]) -> Ai @classmethod def _extract_from_state_message( cls, - state: Optional[List[AirbyteStateMessage]], - ) -> Tuple[ - Optional[AirbyteStateBlob], - MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]], + state: list[AirbyteStateMessage] | None, + ) -> tuple[ + AirbyteStateBlob | None, + MutableMapping[HashableStreamDescriptor, AirbyteStateBlob | None], ]: - """ - Takes an incoming list of state messages or a global state message and extracts state attributes according to + """Takes an incoming list of state messages or a global state message and extracts state attributes according to type which can then be assigned to the new state manager being instantiated :param state: The incoming state input :return: A tuple of shared state and per stream state assembled from the incoming state list @@ -128,22 +122,21 @@ def _extract_from_state_message( for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state } return shared_state, streams - else: - streams = { - HashableStreamDescriptor( - name=per_stream_state.stream.stream_descriptor.name, - namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor - ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state - for per_stream_state in state - if per_stream_state.type == AirbyteStateType.STREAM - and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True - } - return None, streams + streams = { + HashableStreamDescriptor( + name=per_stream_state.stream.stream_descriptor.name, + namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor + ): 
per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state + for per_stream_state in state + if per_stream_state.type == AirbyteStateType.STREAM + and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True + } + return None, streams @staticmethod - def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool: + def _is_global_state(state: list[AirbyteStateMessage] | MutableMapping[str, Any]) -> bool: return ( - isinstance(state, List) + isinstance(state, list) and len(state) == 1 and isinstance(state[0], AirbyteStateMessage) and state[0].type == AirbyteStateType.GLOBAL @@ -151,6 +144,6 @@ def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, @staticmethod def _is_per_stream_state( - state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]], + state: list[AirbyteStateMessage] | MutableMapping[str, Any], ) -> bool: - return isinstance(state, List) + return isinstance(state, list) diff --git a/airbyte_cdk/sources/declarative/async_job/job.py b/airbyte_cdk/sources/declarative/async_job/job.py index b075b61e..3c561bbb 100644 --- a/airbyte_cdk/sources/declarative/async_job/job.py +++ b/airbyte_cdk/sources/declarative/async_job/job.py @@ -1,31 +1,28 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. - +from __future__ import annotations from datetime import timedelta -from typing import Optional +from .status import AsyncJobStatus from airbyte_cdk.sources.declarative.async_job.timer import Timer from airbyte_cdk.sources.types import StreamSlice -from .status import AsyncJobStatus - class AsyncJob: - """ - Description of an API job. + """Description of an API job. Note that the timer will only stop once `update_status` is called so the job might be completed on the API side but until we query for it and call `ApiJob.update_status`, `ApiJob.status` will not reflect the actual API side status. """ def __init__( - self, api_job_id: str, job_parameters: StreamSlice, timeout: Optional[timedelta] = None + self, api_job_id: str, job_parameters: StreamSlice, timeout: timedelta | None = None ) -> None: self._api_job_id = api_job_id self._job_parameters = job_parameters self._status = AsyncJobStatus.RUNNING - timeout = timeout if timeout else timedelta(minutes=60) + timeout = timeout or timedelta(minutes=60) self._timer = Timer(timeout) self._timer.start() diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py index d94885fa..e485e5a1 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py +++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py @@ -1,22 +1,16 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
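Two rewrites in the job.py and job_orchestrator.py hunks above are easy to misread as behavior changes, so here is a short self-contained check of the equivalences they rely on (illustrative values only): `timeout or timedelta(minutes=60)` matches the old conditional expression because both None and timedelta(0) are falsy, and `dict.fromkeys(jobs, 1)` matches the old comprehension because the shared default is an immutable int.

from __future__ import annotations

from datetime import timedelta


def pick_timeout(timeout: timedelta | None) -> timedelta:
    # Same result as `timeout if timeout else timedelta(minutes=60)`.
    return timeout or timedelta(minutes=60)


assert pick_timeout(None) == timedelta(minutes=60)
assert pick_timeout(timedelta(0)) == timedelta(minutes=60)
assert pick_timeout(timedelta(minutes=5)) == timedelta(minutes=5)

jobs = ["job-a", "job-b"]
attempts = dict.fromkeys(jobs, 1)
assert attempts == {job: 1 for job in jobs}
attempts["job-a"] += 1  # rebinds the key to a new int; nothing is shared between keys
assert attempts == {"job-a": 2, "job-b": 1}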
+from __future__ import annotations import logging import threading import time import traceback import uuid +from collections.abc import Generator, Iterable, Mapping from datetime import timedelta from typing import ( Any, - Generator, Generic, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - Type, TypeVar, ) @@ -34,20 +28,19 @@ from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException + LOGGER = logging.getLogger("airbyte") _NO_TIMEOUT = timedelta.max _API_SIDE_RUNNING_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} class AsyncPartition: - """ - This bucket of api_jobs is a bit useless for this iteration but should become interesting when we will be able to split jobs - """ + """This bucket of api_jobs is a bit useless for this iteration but should become interesting when we will be able to split jobs""" _MAX_NUMBER_OF_ATTEMPTS = 3 - def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None: - self._attempts_per_job = {job: 1 for job in jobs} + def __init__(self, jobs: list[AsyncJob], stream_slice: StreamSlice) -> None: + self._attempts_per_job = dict.fromkeys(jobs, 1) self._stream_slice = stream_slice def has_reached_max_attempt(self) -> bool: @@ -58,11 +51,11 @@ def has_reached_max_attempt(self) -> bool: ) ) - def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None: + def replace_job(self, job_to_replace: AsyncJob, new_jobs: list[AsyncJob]) -> None: current_attempt_count = self._attempts_per_job.pop(job_to_replace, None) if current_attempt_count is None: raise ValueError("Could not find job to replace") - elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS: + if current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS: raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}") new_attempt_count = current_attempt_count + 1 @@ -70,9 +63,7 @@ def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> Non self._attempts_per_job[job] = new_attempt_count def should_split(self, job: AsyncJob) -> bool: - """ - Not used right now but once we support job split, we should split based on the number of attempts - """ + """Not used right now but once we support job split, we should split based on the number of attempts""" return False @property @@ -85,18 +76,15 @@ def stream_slice(self) -> StreamSlice: @property def status(self) -> AsyncJobStatus: - """ - Given different job statuses, the priority is: FAILED, TIMED_OUT, RUNNING. Else, it means everything is completed. - """ + """Given different job statuses, the priority is: FAILED, TIMED_OUT, RUNNING. 
Else, it means everything is completed.""" statuses = set(map(lambda job: job.status(), self.jobs)) if statuses == {AsyncJobStatus.COMPLETED}: return AsyncJobStatus.COMPLETED - elif AsyncJobStatus.FAILED in statuses: + if AsyncJobStatus.FAILED in statuses: return AsyncJobStatus.FAILED - elif AsyncJobStatus.TIMED_OUT in statuses: + if AsyncJobStatus.TIMED_OUT in statuses: return AsyncJobStatus.TIMED_OUT - else: - return AsyncJobStatus.RUNNING + return AsyncJobStatus.RUNNING def __repr__(self) -> str: return f"AsyncPartition(stream_slice={self._stream_slice}, attempt_per_job={self._attempts_per_job})" @@ -111,16 +99,15 @@ def __json_serializable__(self) -> Any: class LookaheadIterator(Generic[T]): def __init__(self, iterable: Iterable[T]) -> None: self._iterator = iter(iterable) - self._buffer: List[T] = [] + self._buffer: list[T] = [] - def __iter__(self) -> "LookaheadIterator[T]": + def __iter__(self) -> LookaheadIterator[T]: return self def __next__(self) -> T: if self._buffer: return self._buffer.pop() - else: - return next(self._iterator) + return next(self._iterator) def has_next(self) -> bool: if self._buffer: @@ -153,11 +140,10 @@ def __init__( slices: Iterable[StreamSlice], job_tracker: JobTracker, message_repository: MessageRepository, - exceptions_to_break_on: Iterable[Type[Exception]] = tuple(), + exceptions_to_break_on: Iterable[type[Exception]] = tuple(), has_bulk_parent: bool = False, ) -> None: - """ - If the stream slices provided as a parameters relies on a async job streams that relies on the same JobTracker, `has_bulk_parent` + """If the stream slices provided as a parameters relies on a async job streams that relies on the same JobTracker, `has_bulk_parent` needs to be set to True as jobs creation needs to be prioritized on the parent level. Doing otherwise could lead to a situation where the child has taken up all the job budget without room to the parent to create more which would lead to an infinite loop of "trying to start a parent job" and "ConcurrentJobLimitReached". @@ -170,13 +156,13 @@ def __init__( self._job_repository: AsyncJobRepository = job_repository self._slice_iterator = LookaheadIterator(slices) - self._running_partitions: List[AsyncPartition] = [] + self._running_partitions: list[AsyncPartition] = [] self._job_tracker = job_tracker self._message_repository = message_repository - self._exceptions_to_break_on: Tuple[Type[Exception], ...] = tuple(exceptions_to_break_on) + self._exceptions_to_break_on: tuple[type[Exception], ...] = tuple(exceptions_to_break_on) self._has_bulk_parent = has_bulk_parent - self._non_breaking_exceptions: List[Exception] = [] + self._non_breaking_exceptions: list[Exception] = [] def _replace_failed_jobs(self, partition: AsyncPartition) -> None: failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT) @@ -186,10 +172,10 @@ def _replace_failed_jobs(self, partition: AsyncPartition) -> None: partition.replace_job(job, [new_job]) def _start_jobs(self) -> None: - """ - Retry failed jobs and start jobs for each slice in the slice iterator. + """Retry failed jobs and start jobs for each slice in the slice iterator. This method iterates over the running jobs and slice iterator and starts a job for each slice. The started jobs are added to the running partitions. + Returns: None @@ -225,7 +211,7 @@ def _start_jobs(self) -> None: "Waiting before creating more jobs as the limit of concurrent jobs has been reached. Will try again later..." 
) - def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) -> AsyncJob: + def _start_job(self, _slice: StreamSlice, previous_job_id: str | None = None) -> AsyncJob: if previous_job_id: id_to_replace = previous_job_id lazy_log(LOGGER, logging.DEBUG, lambda: f"Attempting to replace job {id_to_replace}...") @@ -246,8 +232,7 @@ def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) def _keep_api_budget_with_failed_job( self, _slice: StreamSlice, exception: Exception, intent: str ) -> AsyncJob: - """ - We have a mechanism to retry job. It is used when a job status is FAILED or TIMED_OUT. The easiest way to retry is to have this job + """We have a mechanism to retry job. It is used when a job status is FAILED or TIMED_OUT. The easiest way to retry is to have this job as created in a failed state and leverage the retry for failed/timed out jobs. This way, we don't have to have another process for retrying jobs that couldn't be started. """ @@ -271,9 +256,8 @@ def _create_failed_job(self, stream_slice: StreamSlice) -> AsyncJob: job.update_status(AsyncJobStatus.FAILED) return job - def _get_running_jobs(self) -> Set[AsyncJob]: - """ - Returns a set of running AsyncJob objects. + def _get_running_jobs(self) -> set[AsyncJob]: + """Returns a set of running AsyncJob objects. Returns: Set[AsyncJob]: A set of AsyncJob objects that are currently running. @@ -286,18 +270,14 @@ def _get_running_jobs(self) -> Set[AsyncJob]: } def _update_jobs_status(self) -> None: - """ - Update the status of all running jobs in the repository. - """ + """Update the status of all running jobs in the repository.""" running_jobs = self._get_running_jobs() if running_jobs: # update the status only if there are RUNNING jobs self._job_repository.update_jobs_status(running_jobs) def _wait_on_status_update(self) -> None: - """ - Waits for a specified amount of time between status updates. - + """Waits for a specified amount of time between status updates. This method is used to introduce a delay between status updates in order to avoid excessive polling. The duration of the delay is determined by the value of `_WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS`. @@ -319,8 +299,8 @@ def _wait_on_status_update(self) -> None: time.sleep(self._WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS) def _process_completed_partition(self, partition: AsyncPartition) -> None: - """ - Process a completed partition. + """Process a completed partition. + Args: partition (AsyncPartition): The completed partition to process. """ @@ -337,8 +317,7 @@ def _process_completed_partition(self, partition: AsyncPartition) -> None: def _process_running_partitions_and_yield_completed_ones( self, ) -> Generator[AsyncPartition, Any, None]: - """ - Process the running partitions. + """Process the running partitions. Yields: AsyncPartition: The processed partition. @@ -346,7 +325,7 @@ def _process_running_partitions_and_yield_completed_ones( Raises: Any: Any exception raised during processing. """ - current_running_partitions: List[AsyncPartition] = [] + current_running_partitions: list[AsyncPartition] = [] for partition in self._running_partitions: match partition.status: case AsyncJobStatus.COMPLETED: @@ -393,13 +372,14 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None: LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}") def _process_partitions_with_errors(self, partition: AsyncPartition) -> None: - """ - Process a partition with status errors (FAILED and TIMEOUT). 
+ """Process a partition with status errors (FAILED and TIMEOUT). Args: partition (AsyncPartition): The partition to process. + Returns: AirbyteTracedException: An exception indicating that at least one job could not be completed. + Raises: AirbyteTracedException: If at least one job could not be completed. """ @@ -412,8 +392,7 @@ def _process_partitions_with_errors(self, partition: AsyncPartition) -> None: ) def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]: - """ - Creates and retrieves completed partitions. + """Creates and retrieves completed partitions. This method continuously starts jobs, updates job status, processes running partitions, logs polling partitions, and waits for status updates. It yields completed partitions as they become available. @@ -483,8 +462,7 @@ def _is_breaking_exception(self, exception: Exception) -> bool: ) def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]: - """ - Fetches records from the given partition's jobs. + """Fetches records from the given partition's jobs. Args: partition (AsyncPartition): The partition containing the jobs. diff --git a/airbyte_cdk/sources/declarative/async_job/job_tracker.py b/airbyte_cdk/sources/declarative/async_job/job_tracker.py index b47fc4ca..aff54038 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_tracker.py +++ b/airbyte_cdk/sources/declarative/async_job/job_tracker.py @@ -1,12 +1,13 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import logging import threading import uuid -from typing import Set from airbyte_cdk.logger import lazy_log + LOGGER = logging.getLogger("airbyte") @@ -16,7 +17,7 @@ class ConcurrentJobLimitReached(Exception): class JobTracker: def __init__(self, limit: int): - self._jobs: Set[str] = set() + self._jobs: set[str] = set() self._limit = limit self._lock = threading.Lock() @@ -31,7 +32,7 @@ def try_to_get_intent(self) -> str: raise ConcurrentJobLimitReached( "Can't allocate more jobs right now: limit already reached" ) - intent = f"intent_{str(uuid.uuid4())}" + intent = f"intent_{uuid.uuid4()!s}" lazy_log( LOGGER, logging.DEBUG, @@ -60,9 +61,7 @@ def add_job(self, intent_or_job_id: str, job_id: str) -> None: self._jobs.remove(intent_or_job_id) def remove_job(self, job_id: str) -> None: - """ - If the job is not allocated as a running job, this method does nothing and it won't raise. - """ + """If the job is not allocated as a running job, this method does nothing and it won't raise.""" lazy_log( LOGGER, logging.DEBUG, diff --git a/airbyte_cdk/sources/declarative/async_job/repository.py b/airbyte_cdk/sources/declarative/async_job/repository.py index 21581ec4..ebd1a92d 100644 --- a/airbyte_cdk/sources/declarative/async_job/repository.py +++ b/airbyte_cdk/sources/declarative/async_job/repository.py @@ -1,7 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
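The job_tracker.py hunk above reserves budget with a placeholder "intent" id and later swaps in the real API job id. A minimal standalone sketch of that pattern follows; the class below is a hypothetical illustration, not the CDK's JobTracker.

from __future__ import annotations

import threading
import uuid


class BudgetExhausted(Exception):
    pass


class TinyJobBudget:
    def __init__(self, limit: int) -> None:
        self._jobs: set[str] = set()
        self._limit = limit
        self._lock = threading.Lock()

    def try_to_get_intent(self) -> str:
        with self._lock:
            if len(self._jobs) >= self._limit:
                raise BudgetExhausted("Can't allocate more jobs right now: limit already reached")
            intent = f"intent_{uuid.uuid4()!s}"
            self._jobs.add(intent)
            return intent

    def add_job(self, intent_id: str, job_id: str) -> None:
        # Replace the placeholder with the real job id without changing the budget count.
        with self._lock:
            self._jobs.discard(intent_id)
            self._jobs.add(job_id)

    def remove_job(self, job_id: str) -> None:
        # Removing an id that was never tracked is a no-op, mirroring the docstring above.
        with self._lock:
            self._jobs.discard(job_id)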
+from __future__ import annotations from abc import abstractmethod -from typing import Any, Iterable, Mapping, Set +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.types import StreamSlice @@ -13,7 +15,7 @@ def start(self, stream_slice: StreamSlice) -> AsyncJob: pass @abstractmethod - def update_jobs_status(self, jobs: Set[AsyncJob]) -> None: + def update_jobs_status(self, jobs: set[AsyncJob]) -> None: pass @abstractmethod @@ -22,8 +24,7 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: @abstractmethod def abort(self, job: AsyncJob) -> None: - """ - Called when we need to stop on the API side. This method can raise NotImplementedError as not all the APIs will support aborting + """Called when we need to stop on the API side. This method can raise NotImplementedError as not all the APIs will support aborting jobs. """ raise NotImplementedError( diff --git a/airbyte_cdk/sources/declarative/async_job/status.py b/airbyte_cdk/sources/declarative/async_job/status.py index 586e7988..505af3d9 100644 --- a/airbyte_cdk/sources/declarative/async_job/status.py +++ b/airbyte_cdk/sources/declarative/async_job/status.py @@ -1,8 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. - +from __future__ import annotations from enum import Enum + _TERMINAL = True @@ -17,8 +18,7 @@ def __init__(self, value: str, is_terminal: bool) -> None: self._is_terminal = is_terminal def is_terminal(self) -> bool: - """ - A status is terminal when a job status can't be updated anymore. For example if a job is completed, it will stay completed but a + """A status is terminal when a job status can't be updated anymore. For example if a job is completed, it will stay completed but a running job might because completed, failed or timed out. """ return self._is_terminal diff --git a/airbyte_cdk/sources/declarative/async_job/timer.py b/airbyte_cdk/sources/declarative/async_job/timer.py index c4e5a9a1..22092425 100644 --- a/airbyte_cdk/sources/declarative/async_job/timer.py +++ b/airbyte_cdk/sources/declarative/async_job/timer.py @@ -1,12 +1,13 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations + from datetime import datetime, timedelta, timezone -from typing import Optional class Timer: def __init__(self, timeout: timedelta) -> None: - self._start_datetime: Optional[datetime] = None - self._end_datetime: Optional[datetime] = None + self._start_datetime: datetime | None = None + self._end_datetime: datetime | None = None self._timeout = timeout def start(self) -> None: @@ -21,7 +22,7 @@ def is_started(self) -> bool: return self._start_datetime is not None @property - def elapsed_time(self) -> Optional[timedelta]: + def elapsed_time(self) -> timedelta | None: if not self._start_datetime: return None diff --git a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py index b749718f..e1867412 100644 --- a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py +++ b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
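timer.py above is compact, so a standalone sketch of the same idea can stand in for a usage note. The `has_timed_out` helper below is a hypothetical addition for illustration; only the start/elapsed behavior appears in the hunk.

from __future__ import annotations

from datetime import datetime, timedelta, timezone


class TinyTimer:
    def __init__(self, timeout: timedelta) -> None:
        self._start_datetime: datetime | None = None
        self._timeout = timeout

    def start(self) -> None:
        self._start_datetime = datetime.now(timezone.utc)

    def is_started(self) -> bool:
        return self._start_datetime is not None

    @property
    def elapsed_time(self) -> timedelta | None:
        if not self._start_datetime:
            return None
        return datetime.now(timezone.utc) - self._start_datetime

    def has_timed_out(self) -> bool:
        elapsed = self.elapsed_time
        return elapsed is not None and elapsed > self._timeout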
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Union +from typing import Any from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import ( AbstractHeaderAuthenticator, @@ -12,15 +14,13 @@ @dataclass class DeclarativeAuthenticator(AbstractHeaderAuthenticator): - """ - Interface used to associate which authenticators can be used as part of the declarative framework - """ + """Interface used to associate which authenticators can be used as part of the declarative framework""" def get_request_params(self) -> Mapping[str, Any]: """HTTP request parameter to add to the requests""" return {} - def get_request_body_data(self) -> Union[Mapping[str, Any], str]: + def get_request_body_data(self) -> Mapping[str, Any] | str: """Form-encoded body data to set on the requests""" return {} diff --git a/airbyte_cdk/sources/declarative/auth/jwt.py b/airbyte_cdk/sources/declarative/auth/jwt.py index 4095635d..669dfa63 100644 --- a/airbyte_cdk/sources/declarative/auth/jwt.py +++ b/airbyte_cdk/sources/declarative/auth/jwt.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import base64 +from collections.abc import Mapping from dataclasses import InitVar, dataclass from datetime import datetime -from typing import Any, Mapping, Optional, Union +from typing import Any import jwt + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping @@ -15,9 +18,7 @@ class JwtAlgorithm(str): - """ - Enum for supported JWT algorithms - """ + """Enum for supported JWT algorithms""" HS256 = "HS256" HS384 = "HS384" @@ -37,8 +38,7 @@ class JwtAlgorithm(str): @dataclass class JwtAuthenticator(DeclarativeAuthenticator): - """ - Generates a JSON Web Token (JWT) based on a declarative connector configuration file. The generated token is attached to each request via the Authorization header. + """Generates a JSON Web Token (JWT) based on a declarative connector configuration file. The generated token is attached to each request via the Authorization header. 
Attributes: config (Mapping[str, Any]): The user-provided configuration as specified by the source's spec @@ -59,19 +59,19 @@ class JwtAuthenticator(DeclarativeAuthenticator): config: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] - secret_key: Union[InterpolatedString, str] - algorithm: Union[str, JwtAlgorithm] - token_duration: Optional[int] - base64_encode_secret_key: Optional[Union[InterpolatedBoolean, str, bool]] = False - header_prefix: Optional[Union[InterpolatedString, str]] = None - kid: Optional[Union[InterpolatedString, str]] = None - typ: Optional[Union[InterpolatedString, str]] = None - cty: Optional[Union[InterpolatedString, str]] = None - iss: Optional[Union[InterpolatedString, str]] = None - sub: Optional[Union[InterpolatedString, str]] = None - aud: Optional[Union[InterpolatedString, str]] = None - additional_jwt_headers: Optional[Mapping[str, Any]] = None - additional_jwt_payload: Optional[Mapping[str, Any]] = None + secret_key: InterpolatedString | str + algorithm: str | JwtAlgorithm + token_duration: int | None + base64_encode_secret_key: InterpolatedBoolean | str | bool | None = False + header_prefix: InterpolatedString | str | None = None + kid: InterpolatedString | str | None = None + typ: InterpolatedString | str | None = None + cty: InterpolatedString | str | None = None + iss: InterpolatedString | str | None = None + sub: InterpolatedString | str | None = None + aud: InterpolatedString | str | None = None + additional_jwt_headers: Mapping[str, Any] | None = None + additional_jwt_payload: Mapping[str, Any] | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters) @@ -122,9 +122,7 @@ def _get_jwt_headers(self) -> dict[str, Any]: return headers def _get_jwt_payload(self) -> dict[str, Any]: - """ - Builds and returns the payload used when signing the JWT. - """ + """Builds and returns the payload used when signing the JWT.""" now = int(datetime.now().timestamp()) exp = now + self._token_duration if isinstance(self._token_duration, int) else now nbf = now @@ -147,9 +145,7 @@ def _get_jwt_payload(self) -> dict[str, Any]: return payload def _get_secret_key(self) -> str: - """ - Returns the secret key used to sign the JWT. - """ + """Returns the secret key used to sign the JWT.""" secret_key: str = self._secret_key.eval(self.config) return ( base64.b64encode(secret_key.encode()).decode() @@ -157,10 +153,8 @@ def _get_secret_key(self) -> str: else secret_key ) - def _get_signed_token(self) -> Union[str, Any]: - """ - Signed the JWT using the provided secret key and algorithm and the generated headers and payload. For additional information on PyJWT see: https://pyjwt.readthedocs.io/en/stable/ - """ + def _get_signed_token(self) -> str | Any: + """Signed the JWT using the provided secret key and algorithm and the generated headers and payload. For additional information on PyJWT see: https://pyjwt.readthedocs.io/en/stable/""" try: return jwt.encode( payload=self._get_jwt_payload(), @@ -171,10 +165,8 @@ def _get_signed_token(self) -> Union[str, Any]: except Exception as e: raise ValueError(f"Failed to sign token: {e}") - def _get_header_prefix(self) -> Union[str, None]: - """ - Returns the header prefix to be used when attaching the token to the request. 
- """ + def _get_header_prefix(self) -> str | None: + """Returns the header prefix to be used when attaching the token to the request.""" return self._header_prefix.eval(self.config) if self._header_prefix else None @property diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 773d2818..90210165 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, Optional, Union +from typing import Any import pendulum + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -20,8 +23,7 @@ @dataclass class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAuthenticator): - """ - Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials based on + """Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials based on a declarative connector configuration file. Credentials can be defined explicitly or via interpolation at runtime. The generated access token is attached to each request via the Authorization header. @@ -42,21 +44,21 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut message_repository (MessageRepository): the message repository used to emit logs on HTTP requests """ - token_refresh_endpoint: Union[InterpolatedString, str] - client_id: Union[InterpolatedString, str] - client_secret: Union[InterpolatedString, str] + token_refresh_endpoint: InterpolatedString | str + client_id: InterpolatedString | str + client_secret: InterpolatedString | str config: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] - refresh_token: Optional[Union[InterpolatedString, str]] = None - scopes: Optional[List[str]] = None - token_expiry_date: Optional[Union[InterpolatedString, str]] = None - _token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None) - token_expiry_date_format: Optional[str] = None + refresh_token: InterpolatedString | str | None = None + scopes: list[str] | None = None + token_expiry_date: InterpolatedString | str | None = None + _token_expiry_date: pendulum.DateTime | None = field(init=False, repr=False, default=None) + token_expiry_date_format: str | None = None token_expiry_is_time_of_expiration: bool = False - access_token_name: Union[InterpolatedString, str] = "access_token" - expires_in_name: Union[InterpolatedString, str] = "expires_in" - refresh_request_body: Optional[Mapping[str, Any]] = None - grant_type: Union[InterpolatedString, str] = "refresh_token" + access_token_name: InterpolatedString | str = "access_token" + expires_in_name: InterpolatedString | str = "expires_in" + refresh_request_body: Mapping[str, Any] | None = None + grant_type: InterpolatedString | str = "refresh_token" message_repository: MessageRepository = NoopMessageRepository() def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -67,7 +69,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._client_id = InterpolatedString.create(self.client_id, 
parameters=parameters) self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters) if self.refresh_token is not None: - self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create( + self._refresh_token: InterpolatedString | None = InterpolatedString.create( self.refresh_token, parameters=parameters ) else: @@ -91,7 +93,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.token_expiry_date else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints ) - self._access_token: Optional[str] = None # access_token is initialized by a setter + self._access_token: str | None = None # access_token is initialized by a setter if self.get_grant_type() == "refresh_token" and self._refresh_token is None: raise ValueError( @@ -118,10 +120,10 @@ def get_client_secret(self) -> str: raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter") return client_secret - def get_refresh_token(self) -> Optional[str]: + def get_refresh_token(self) -> str | None: return None if self._refresh_token is None else str(self._refresh_token.eval(self.config)) - def get_scopes(self) -> List[str]: + def get_scopes(self) -> list[str]: return self.scopes or [] def get_access_token_name(self) -> str: @@ -139,7 +141,7 @@ def get_refresh_request_body(self) -> Mapping[str, Any]: def get_token_expiry_date(self) -> pendulum.DateTime: return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: str | int) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) @property @@ -154,9 +156,7 @@ def access_token(self, value: str) -> None: @property def _message_repository(self) -> MessageRepository: - """ - Overriding AbstractOauth2Authenticator._message_repository to allow for HTTP request logs - """ + """Overriding AbstractOauth2Authenticator._message_repository to allow for HTTP request logs""" return self.message_repository @@ -164,9 +164,7 @@ def _message_repository(self) -> MessageRepository: class DeclarativeSingleUseRefreshTokenOauth2Authenticator( SingleUseRefreshTokenOauth2Authenticator, DeclarativeAuthenticator ): - """ - Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors. - """ + """Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors.""" def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) diff --git a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py index 11a2ae7d..0769c33e 100644 --- a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py +++ b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
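The jwt.py hunk above ultimately hands the payload, headers, and algorithm to PyJWT. The standalone sketch below shows that call with illustrative claim values; the issuer, subject, secret, and key id are placeholders, not connector configuration.

from __future__ import annotations

from datetime import datetime, timezone

import jwt  # PyJWT, the same library imported in the hunk above

now = int(datetime.now(timezone.utc).timestamp())
token = jwt.encode(
    payload={"iss": "example-issuer", "sub": "example-subject", "iat": now, "exp": now + 1200},
    key="example-secret",
    algorithm="HS256",
    headers={"kid": "example-key-id", "typ": "JWT"},
)
# With a header_prefix such as "Bearer", the request header value becomes f"Bearer {token}".
print(token)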
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, List, Mapping +from typing import Any import dpath + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator @@ -15,14 +18,14 @@ class SelectiveAuthenticator(DeclarativeAuthenticator): config: Mapping[str, Any] authenticators: Mapping[str, DeclarativeAuthenticator] - authenticator_selection_path: List[str] + authenticator_selection_path: list[str] # returns "DeclarativeAuthenticator", but must return a subtype of "SelectiveAuthenticator" def __new__( # type: ignore[misc] cls, config: Mapping[str, Any], authenticators: Mapping[str, DeclarativeAuthenticator], - authenticator_selection_path: List[str], + authenticator_selection_path: list[str], *arg: Any, **kwargs: Any, ) -> DeclarativeAuthenticator: diff --git a/airbyte_cdk/sources/declarative/auth/token.py b/airbyte_cdk/sources/declarative/auth/token.py index dc35eb45..ae6b091d 100644 --- a/airbyte_cdk/sources/declarative/auth/token.py +++ b/airbyte_cdk/sources/declarative/auth/token.py @@ -1,13 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import base64 import logging +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Union +from typing import Any import requests +from cachetools import TTLCache, cached + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.auth.token_provider import TokenProvider from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -16,13 +20,11 @@ RequestOptionType, ) from airbyte_cdk.sources.types import Config -from cachetools import TTLCache, cached @dataclass class ApiKeyAuthenticator(DeclarativeAuthenticator): - """ - ApiKeyAuth sets a request header on the HTTP requests sent. + """ApiKeyAuth sets a request header on the HTTP requests sent. The header is of the form: `"
": ""` @@ -67,7 +69,7 @@ def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, A def get_request_params(self) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter) - def get_request_body_data(self) -> Union[Mapping[str, Any], str]: + def get_request_body_data(self) -> Mapping[str, Any] | str: return self._get_request_options(RequestOptionType.body_data) def get_request_body_json(self) -> Mapping[str, Any]: @@ -76,8 +78,7 @@ def get_request_body_json(self) -> Mapping[str, Any]: @dataclass class BearerAuthenticator(DeclarativeAuthenticator): - """ - Authenticator that sets the Authorization header on the HTTP requests sent. + """Authenticator that sets the Authorization header on the HTTP requests sent. The header is of the form: `"Authorization": "Bearer "` @@ -103,8 +104,7 @@ def token(self) -> str: @dataclass class BasicHttpAuthenticator(DeclarativeAuthenticator): - """ - Builds auth based off the basic authentication scheme as defined by RFC 7617, which transmits credentials as USER ID/password pairs, encoded using base64 + """Builds auth based off the basic authentication scheme as defined by RFC 7617, which transmits credentials as USER ID/password pairs, encoded using base64 https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication#basic_authentication_scheme The header is of the form @@ -117,10 +117,10 @@ class BasicHttpAuthenticator(DeclarativeAuthenticator): parameters (Mapping[str, Any]): Additional runtime parameters to be used for string interpolation """ - username: Union[InterpolatedString, str] + username: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - password: Union[InterpolatedString, str] = "" + password: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._username = InterpolatedString.create(self.username, parameters=parameters) @@ -133,7 +133,7 @@ def auth_header(self) -> str: @property def token(self) -> str: auth_string = ( - f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode("utf8") + f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode() ) b64_encoded = base64.b64encode(auth_string).decode("utf8") return f"Basic {b64_encoded}" @@ -152,9 +152,9 @@ def token(self) -> str: @cached(cacheSessionTokenAuthenticator) def get_new_session_token(api_url: str, username: str, password: str, response_key: str) -> str: - """ - This method retrieves session token from api by username and password for SessionTokenAuthenticator. + """This method retrieves session token from api by username and password for SessionTokenAuthenticator. It's cashed to avoid a multiple calling by sync and updating session token every stream sync. + Args: api_url: api url for getting new session token username: username for auth @@ -179,8 +179,7 @@ def get_new_session_token(api_url: str, username: str, password: str, response_k @dataclass class LegacySessionTokenAuthenticator(DeclarativeAuthenticator): - """ - Builds auth based on session tokens. + """Builds auth based on session tokens. A session token is a random value generated by a server to identify a specific user for the duration of one interaction session. 
@@ -200,16 +199,16 @@ class LegacySessionTokenAuthenticator(DeclarativeAuthenticator): validate_session_url (Union[InterpolatedString, str]): Url to validate passed session token """ - api_url: Union[InterpolatedString, str] - header: Union[InterpolatedString, str] - session_token: Union[InterpolatedString, str] - session_token_response_key: Union[InterpolatedString, str] - username: Union[InterpolatedString, str] + api_url: InterpolatedString | str + header: InterpolatedString | str + session_token: InterpolatedString | str + session_token_response_key: InterpolatedString | str + username: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - login_url: Union[InterpolatedString, str] - validate_session_url: Union[InterpolatedString, str] - password: Union[InterpolatedString, str] = "" + login_url: InterpolatedString | str + validate_session_url: InterpolatedString | str + password: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._username = InterpolatedString.create(self.username, parameters=parameters) @@ -258,14 +257,12 @@ def is_valid_session_token(self) -> bool: response.raise_for_status() except requests.exceptions.HTTPError as e: if e.response.status_code == requests.codes["unauthorized"]: - self.logger.info(f"Unable to connect by session token from config due to {str(e)}") + self.logger.info(f"Unable to connect by session token from config due to {e!s}") return False - else: - raise ConnectionError(f"Error while validating session token: {e}") + raise ConnectionError(f"Error while validating session token: {e}") if response.ok: self.logger.info("Connection check for source is successful.") return True - else: - raise ConnectionError( - f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}" - ) + raise ConnectionError( + f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}" + ) diff --git a/airbyte_cdk/sources/declarative/auth/token_provider.py b/airbyte_cdk/sources/declarative/auth/token_provider.py index c3c2a41f..facb3d63 100644 --- a/airbyte_cdk/sources/declarative/auth/token_provider.py +++ b/airbyte_cdk/sources/declarative/auth/token_provider.py @@ -1,15 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
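Two details of the token.py hunks above are worth spelling out: str.encode() with no argument already defaults to UTF-8, so dropping the "utf8" argument is behavior-preserving, and the Basic header is simply the base64 of "username:password". A short self-contained check with illustrative credentials:

import base64

username, password = "user@example.com", "hunter2"  # illustrative values only

auth_string = f"{username}:{password}".encode()  # .encode() defaults to UTF-8
assert auth_string == f"{username}:{password}".encode("utf8")

header_value = f"Basic {base64.b64encode(auth_string).decode('utf8')}"
print(header_value)  # "Basic <base64 of username:password>"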
# - +from __future__ import annotations import datetime from abc import abstractmethod +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, Optional, Union +from typing import Any import dpath import pendulum +from isodate import Duration +from pendulum import DateTime + from airbyte_cdk.sources.declarative.decoders.decoder import Decoder from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder from airbyte_cdk.sources.declarative.exceptions import ReadException @@ -18,8 +22,6 @@ from airbyte_cdk.sources.http_logger import format_http_message from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository from airbyte_cdk.sources.types import Config -from isodate import Duration -from pendulum import DateTime class TokenProvider: @@ -31,14 +33,14 @@ def get_token(self) -> str: @dataclass class SessionTokenProvider(TokenProvider): login_requester: Requester - session_token_path: List[str] - expiration_duration: Optional[Union[datetime.timedelta, Duration]] + session_token_path: list[str] + expiration_duration: datetime.timedelta | Duration | None parameters: InitVar[Mapping[str, Any]] message_repository: MessageRepository = NoopMessageRepository() decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={})) - _next_expiration_time: Optional[DateTime] = None - _token: Optional[str] = None + _next_expiration_time: DateTime | None = None + _token: str | None = None def get_token(self) -> str: self._refresh_if_necessary() @@ -71,7 +73,7 @@ def _refresh(self) -> None: @dataclass class InterpolatedStringTokenProvider(TokenProvider): config: Config - api_token: Union[InterpolatedString, str] + api_token: InterpolatedString | str parameters: Mapping[str, Any] def __post_init__(self) -> None: diff --git a/airbyte_cdk/sources/declarative/checks/check_stream.py b/airbyte_cdk/sources/declarative/checks/check_stream.py index c45159ec..efe8e9d7 100644 --- a/airbyte_cdk/sources/declarative/checks/check_stream.py +++ b/airbyte_cdk/sources/declarative/checks/check_stream.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging import traceback +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, List, Mapping, Tuple +from typing import Any from airbyte_cdk import AbstractSource from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker @@ -14,14 +16,13 @@ @dataclass class CheckStream(ConnectionChecker): - """ - Checks the connections by checking availability of one or many streams selected by the developer + """Checks the connections by checking availability of one or many streams selected by the developer Attributes: stream_name (List[str]): names of streams to check """ - stream_names: List[str] + stream_names: list[str] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -29,13 +30,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def check_connection( self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Any]: + ) -> tuple[bool, Any]: streams = source.streams(config=config) stream_name_to_stream = {s.name: s for s in streams} if len(streams) == 0: return False, f"No streams to connect to from source {source}" for stream_name in self.stream_names: - if stream_name not in stream_name_to_stream.keys(): + if stream_name not in stream_name_to_stream: raise ValueError( f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}." ) diff --git a/airbyte_cdk/sources/declarative/checks/connection_checker.py b/airbyte_cdk/sources/declarative/checks/connection_checker.py index fd1d1bba..23edafac 100644 --- a/airbyte_cdk/sources/declarative/checks/connection_checker.py +++ b/airbyte_cdk/sources/declarative/checks/connection_checker.py @@ -1,25 +1,24 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Any, Mapping, Tuple +from collections.abc import Mapping +from typing import Any from airbyte_cdk import AbstractSource class ConnectionChecker(ABC): - """ - Abstract base class for checking a connection - """ + """Abstract base class for checking a connection""" @abstractmethod def check_connection( self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Any]: - """ - Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect + ) -> tuple[bool, Any]: + """Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect to the Stripe API. :param source: source diff --git a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py index f5cd24f0..504b6435 100644 --- a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py +++ b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py @@ -1,9 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
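One note on the check_stream.py hunk above: dropping .keys() is behavior-preserving because membership on a dict already tests its keys. A tiny standalone sketch of the same check-that-streams-exist loop, using hypothetical stream names rather than CDK types:

from __future__ import annotations


def check_streams(stream_names: list[str], available: dict[str, object]) -> tuple[bool, object]:
    # `name in available` is equivalent to `name in available.keys()`.
    for name in stream_names:
        if name not in available:
            return False, f"{name} is not part of the catalog. Expected one of {list(available)}."
    return True, None


assert check_streams(["users"], {"users": object(), "orders": object()}) == (True, None)
assert check_streams(["missing"], {"users": object()})[0] is False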
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.types import Config @@ -11,22 +13,21 @@ @dataclass class ConcurrencyLevel: - """ - Returns the number of worker threads that should be used when syncing concurrent streams in parallel + """Returns the number of worker threads that should be used when syncing concurrent streams in parallel Attributes: default_concurrency (Union[int, str]): The hardcoded integer or interpolation of how many worker threads to use during a sync max_concurrency (Optional[int]): The maximum number of worker threads to use when the default_concurrency is exceeded """ - default_concurrency: Union[int, str] - max_concurrency: Optional[int] + default_concurrency: int | str + max_concurrency: int | None config: Config parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.default_concurrency, int): - self._default_concurrency: Union[int, InterpolatedString] = self.default_concurrency + self._default_concurrency: int | InterpolatedString = self.default_concurrency elif "config" in self.default_concurrency and not self.max_concurrency: raise ValueError( "ConcurrencyLevel requires that max_concurrency be defined if the default_concurrency can be used-specified" @@ -46,5 +47,4 @@ def get_concurrency_level(self) -> int: if self.max_concurrency else evaluated_default_concurrency ) - else: - return self._default_concurrency + return self._default_concurrency diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 62e0b578..ff02ed14 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -1,9 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging -from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union +from collections.abc import Iterator, Mapping +from typing import Any, Generic from airbyte_cdk.models import ( AirbyteCatalog, @@ -49,13 +51,13 @@ class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]): def __init__( self, - catalog: Optional[ConfiguredAirbyteCatalog], - config: Optional[Mapping[str, Any]], + catalog: ConfiguredAirbyteCatalog | None, + config: Mapping[str, Any] | None, state: TState, source_config: ConnectionDefinition, debug: bool = False, emit_connector_builder_messages: bool = False, - component_factory: Optional[ModelToComponentFactory] = None, + component_factory: ModelToComponentFactory | None = None, **kwargs: Any, ) -> None: super().__init__( @@ -67,8 +69,8 @@ def __init__( self._state = state - self._concurrent_streams: Optional[List[AbstractStream]] - self._synchronous_streams: Optional[List[Stream]] + self._concurrent_streams: list[AbstractStream] | None + self._synchronous_streams: list[Stream] | None # If the connector command was SPEC, there is no incoming config, and we cannot instantiate streams because # they might depend on it. 
Ideally we want to have a static method on this class to get the spec without @@ -115,7 +117,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[Union[List[AirbyteStateMessage]]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: # ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of the concurrent # streams must be saved so that they can be removed from the catalog before starting synchronous streams @@ -152,9 +154,8 @@ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> Airbyte ] ) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: - """ - The `streams` method is used as part of the AbstractSource in the following cases: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: + """The `streams` method is used as part of the AbstractSource in the following cases: * ConcurrentDeclarativeSource.check -> ManifestDeclarativeSource.check -> AbstractSource.check -> DeclarativeSource.check_connection -> CheckStream.check_connection -> streams * ConcurrentDeclarativeSource.read -> AbstractSource.read -> streams (note that we filter for a specific catalog which excludes concurrent streams so not all streams actually read from all the streams returned by `streams`) Note that `super.streams(config)` is also called when splitting the streams between concurrent or not in `_group_streams`. @@ -165,9 +166,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: def _group_streams( self, config: Mapping[str, Any] - ) -> Tuple[List[AbstractStream], List[Stream]]: - concurrent_streams: List[AbstractStream] = [] - synchronous_streams: List[Stream] = [] + ) -> tuple[list[AbstractStream], list[Stream]]: + concurrent_streams: list[AbstractStream] = [] + synchronous_streams: list[Stream] = [] state_manager = ConnectorStateManager(state=self._state) # type: ignore # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later @@ -259,8 +260,7 @@ def _group_streams( def _stream_supports_concurrent_partition_processing( self, declarative_stream: DeclarativeStream ) -> bool: - """ - Many connectors make use of stream_state during interpolation on a per-partition basis under the assumption that + """Many connectors make use of stream_state during interpolation on a per-partition basis under the assumption that state is updated sequentially. Because the concurrent CDK engine processes different partitions in parallel, stream_state is no longer a thread-safe interpolation context. It would be a race condition because a cursor's stream_state can be updated in any order depending on which stream partition's finish first. @@ -269,7 +269,6 @@ def _stream_supports_concurrent_partition_processing( per-partition, but we need to gate this otherwise some connectors will be blocked from publishing. See the cdk-migrations.md for the full list of connectors. 
""" - if isinstance(declarative_stream.retriever, SimpleRetriever) and isinstance( declarative_stream.retriever.requester, HttpRequester ): @@ -321,10 +320,10 @@ def _stream_supports_concurrent_partition_processing( @staticmethod def _select_streams( - streams: List[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog - ) -> List[AbstractStream]: + streams: list[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog + ) -> list[AbstractStream]: stream_name_to_instance: Mapping[str, AbstractStream] = {s.name: s for s in streams} - abstract_streams: List[AbstractStream] = [] + abstract_streams: list[AbstractStream] = [] for configured_stream in configured_catalog.streams: stream_instance = stream_name_to_instance.get(configured_stream.stream.name) if stream_instance: diff --git a/airbyte_cdk/sources/declarative/datetime/datetime_parser.py b/airbyte_cdk/sources/declarative/datetime/datetime_parser.py index 93122e29..3a632951 100644 --- a/airbyte_cdk/sources/declarative/datetime/datetime_parser.py +++ b/airbyte_cdk/sources/declarative/datetime/datetime_parser.py @@ -1,14 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime -from typing import Union class DatetimeParser: - """ - Parses and formats datetime objects according to a specified format. + """Parses and formats datetime objects according to a specified format. This class mainly acts as a wrapper to properly handling timestamp formatting through the "%s" directive. @@ -18,7 +17,7 @@ class DatetimeParser: _UNIX_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) - def parse(self, date: Union[str, int], format: str) -> datetime.datetime: + def parse(self, date: str | int, format: str) -> datetime.datetime: # "%s" is a valid (but unreliable) directive for formatting, but not for parsing # It is defined as # The number of seconds since the Epoch, 1970-01-01 00:00:00+0000 (UTC). https://man7.org/linux/man-pages/man3/strptime.3.html @@ -27,9 +26,9 @@ def parse(self, date: Union[str, int], format: str) -> datetime.datetime: # See https://stackoverflow.com/a/4974930 if format == "%s": return datetime.datetime.fromtimestamp(int(date), tz=datetime.timezone.utc) - elif format == "%s_as_float": + if format == "%s_as_float": return datetime.datetime.fromtimestamp(float(date), tz=datetime.timezone.utc) - elif format == "%ms": + if format == "%ms": return self._UNIX_EPOCH + datetime.timedelta(milliseconds=int(date)) parsed_datetime = datetime.datetime.strptime(str(date), format) @@ -48,8 +47,7 @@ def format(self, dt: datetime.datetime, format: str) -> str: if format == "%ms": # timstamp() returns a float representing the number of seconds since the unix epoch return str(int(dt.timestamp() * 1000)) - else: - return dt.strftime(format) + return dt.strftime(format) def _is_naive(self, dt: datetime.datetime) -> bool: return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None diff --git a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py index 1edf9243..95029301 100644 --- a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py +++ b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import datetime as dt +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -12,8 +14,7 @@ @dataclass class MinMaxDatetime: - """ - Compares the provided date against optional minimum or maximum times. If date is earlier than + """Compares the provided date against optional minimum or maximum times. If date is earlier than min_date, then min_date is returned. If date is greater than max_date, then max_date is returned. If neither, the input date is returned. @@ -28,14 +29,14 @@ class MinMaxDatetime: max_datetime (Union[InterpolatedString, str]): Represents the maximum allowed datetime value. """ - datetime: Union[InterpolatedString, str] + datetime: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] # datetime_format is a unique case where we inherit it from the parent if it is not specified before using the default value # which is why we need dedicated getter/setter methods and private dataclass field datetime_format: str _datetime_format: str = field(init=False, repr=False, default="") - min_datetime: Union[InterpolatedString, str] = "" - max_datetime: Union[InterpolatedString, str] = "" + min_datetime: InterpolatedString | str = "" + max_datetime: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {}) @@ -54,8 +55,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_datetime( self, config: Mapping[str, Any], **additional_parameters: Mapping[str, Any] ) -> dt.datetime: - """ - Evaluates and returns the datetime + """Evaluates and returns the datetime :param config: The user-provided configuration as specified by the source's spec :param additional_parameters: Additional arguments to be passed to the strings for interpolation :return: The evaluated datetime @@ -97,9 +97,9 @@ def datetime_format(self, value: str) -> None: @classmethod def create( cls, - interpolated_string_or_min_max_datetime: Union[InterpolatedString, str, "MinMaxDatetime"], - parameters: Optional[Mapping[str, Any]] = None, - ) -> "MinMaxDatetime": + interpolated_string_or_min_max_datetime: InterpolatedString | str | MinMaxDatetime, + parameters: Mapping[str, Any] | None = None, + ) -> MinMaxDatetime: if parameters is None: parameters = {} if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance( @@ -108,5 +108,4 @@ def create( return MinMaxDatetime( datetime=interpolated_string_or_min_max_datetime, parameters=parameters ) - else: - return interpolated_string_or_min_max_datetime + return interpolated_string_or_min_max_datetime diff --git a/airbyte_cdk/sources/declarative/declarative_source.py b/airbyte_cdk/sources/declarative/declarative_source.py index 77bf427a..27eadf13 100644 --- a/airbyte_cdk/sources/declarative/declarative_source.py +++ b/airbyte_cdk/sources/declarative/declarative_source.py @@ -1,19 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
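An illustrative (hypothetical) use of the clamping behaviour, with an invented `start_date` config key:

    from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime

    start = MinMaxDatetime(
        datetime="{{ config['start_date'] }}",
        datetime_format="%Y-%m-%dT%H:%M:%S",
        min_datetime="2021-01-01T00:00:00",
        parameters={},
    )

    # The configured date is earlier than min_datetime, so min_datetime should be returned.
    print(start.get_datetime(config={"start_date": "2020-06-01T00:00:00"}))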
# +from __future__ import annotations import logging from abc import abstractmethod -from typing import Any, Mapping, Tuple +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker class DeclarativeSource(AbstractSource): - """ - Base class for declarative Source. Concrete sources need to define the connection_checker to use - """ + """Base class for declarative Source. Concrete sources need to define the connection_checker to use""" @property @abstractmethod @@ -22,9 +22,8 @@ def connection_checker(self) -> ConnectionChecker: def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Any]: - """ - :param logger: The source logger + ) -> tuple[bool, Any]: + """:param logger: The source logger :param config: The user-provided configuration as specified by the source's spec. This usually contains information required to check connection e.g. tokens, secrets and keys etc. :return: A tuple of (boolean, error). If boolean is true, then the connection check is successful diff --git a/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte_cdk/sources/declarative/declarative_stream.py index 12cdd333..bb3ccad1 100644 --- a/airbyte_cdk/sources/declarative/declarative_stream.py +++ b/airbyte_cdk/sources/declarative/declarative_stream.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import logging +from collections.abc import Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.declarative.incremental import ( @@ -29,8 +32,7 @@ @dataclass class DeclarativeStream(Stream): - """ - DeclarativeStream is a Stream that delegates most of its logic to its schema_load and retriever + """DeclarativeStream is a Stream that delegates most of its logic to its schema_load and retriever Attributes: name (str): stream name @@ -46,12 +48,12 @@ class DeclarativeStream(Stream): config: Config parameters: InitVar[Mapping[str, Any]] name: str - primary_key: Optional[Union[str, List[str], List[List[str]]]] - state_migrations: List[StateMigration] = field(repr=True, default_factory=list) - schema_loader: Optional[SchemaLoader] = None + primary_key: str | list[str] | list[list[str]] | None + state_migrations: list[StateMigration] = field(repr=True, default_factory=list) + schema_loader: SchemaLoader | None = None _name: str = field(init=False, repr=False, default="") _primary_key: str = field(init=False, repr=False, default="") - stream_cursor_field: Optional[Union[InterpolatedString, str]] = None + stream_cursor_field: InterpolatedString | str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._stream_cursor_field = ( @@ -59,14 +61,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.stream_cursor_field, str) else self.stream_cursor_field ) - self._schema_loader = ( - self.schema_loader - if self.schema_loader - else DefaultSchemaLoader(config=self.config, parameters=parameters) + self._schema_loader = self.schema_loader or DefaultSchemaLoader( + config=self.config, parameters=parameters ) @property # type: ignore - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + 
def primary_key(self) -> str | list[str] | list[list[str]] | None: return self._primary_key @primary_key.setter @@ -84,9 +84,7 @@ def exit_on_rate_limit(self, value: bool) -> None: @property # type: ignore def name(self) -> str: - """ - :return: Stream name. By default this is the implementing class name, but it can be overridden as needed. - """ + """:return: Stream name. By default this is the implementing class name, but it can be overridden as needed.""" return self._name @name.setter @@ -114,13 +112,12 @@ def get_updated_state( return self.state @property - def cursor_field(self) -> Union[str, List[str]]: - """ - Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. + def cursor_field(self) -> str | list[str]: + """Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor. """ cursor = self._stream_cursor_field.eval(self.config) # type: ignore # _stream_cursor_field is always cast to interpolated string - return cursor if cursor else [] + return cursor or [] @property def is_resumable(self) -> bool: @@ -131,13 +128,11 @@ def is_resumable(self) -> bool: def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: - """ - :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state. - """ + """:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.""" if stream_slice is None or stream_slice == {}: # As the parameter is Optional, many would just call `read_records(sync_mode)` during testing without specifying the field # As part of the declarative model without custom components, this should never happen as the CDK would wire up a @@ -152,8 +147,7 @@ def read_records( yield from self.retriever.read_records(self.get_json_schema(), stream_slice) # type: ignore # records are of the correct type def get_json_schema(self) -> Mapping[str, Any]: # type: ignore - """ - :return: A dict of the JSON schema representing this stream. + """:return: A dict of the JSON schema representing this stream. The default implementation of this method looks for a JSONSchema file with the same name as this stream's "name" property. Override as needed. @@ -164,11 +158,10 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[StreamSlice]]: - """ - Override to define the slices for this stream. See the stream slicing section of the docs for more information. + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[StreamSlice | None]: + """Override to define the slices for this stream. See the stream slicing section of the docs for more information. :param sync_mode: :param cursor_field: @@ -178,9 +171,8 @@ def stream_slices( return self.retriever.stream_slices() @property - def state_checkpoint_interval(self) -> Optional[int]: - """ - We explicitly disable checkpointing here. 
There are a couple reasons for that and not all are documented here but: + def state_checkpoint_interval(self) -> int | None: + """We explicitly disable checkpointing here. There are a couple reasons for that and not all are documented here but: * In the case where records are not ordered, the granularity of what is ordered is the slice. Therefore, we will only update the cursor value once at the end of every slice. * Updating the state once every record would generate issues for data feed stop conditions or semi-incremental syncs where the @@ -188,7 +180,7 @@ def state_checkpoint_interval(self) -> Optional[int]: """ return None - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: if self.retriever and isinstance(self.retriever, SimpleRetriever): return self.retriever.cursor return None @@ -196,12 +188,11 @@ def get_cursor(self) -> Optional[Cursor]: def _get_checkpoint_reader( self, logger: logging.Logger, - cursor_field: Optional[List[str]], + cursor_field: list[str] | None, sync_mode: SyncMode, stream_state: MutableMapping[str, Any], ) -> CheckpointReader: - """ - This method is overridden to prevent issues with stream slice classification for incremental streams that have parent streams. + """This method is overridden to prevent issues with stream slice classification for incremental streams that have parent streams. The classification logic, when used with `itertools.tee`, creates a copy of the stream slices. When `stream_slices` is called the second time, the parent records generated during the classification phase are lost. This occurs because `itertools.tee` @@ -212,7 +203,7 @@ def _get_checkpoint_reader( """ mappings_or_slices = self.stream_slices( cursor_field=cursor_field, - sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior + sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior stream_state=stream_state, ) diff --git a/airbyte_cdk/sources/declarative/decoders/decoder.py b/airbyte_cdk/sources/declarative/decoders/decoder.py index 5fa9dc8f..e0930b7e 100644 --- a/airbyte_cdk/sources/declarative/decoders/decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/decoder.py @@ -1,32 +1,29 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from abc import abstractmethod +from collections.abc import Generator, MutableMapping from dataclasses import dataclass -from typing import Any, Generator, MutableMapping +from typing import Any import requests @dataclass class Decoder: - """ - Decoder strategy to transform a requests.Response into a Mapping[str, Any] - """ + """Decoder strategy to transform a requests.Response into a Mapping[str, Any]""" @abstractmethod def is_stream_response(self) -> bool: - """ - Set to True if you'd like to use stream=True option in http requester - """ + """Set to True if you'd like to use stream=True option in http requester""" @abstractmethod def decode( self, response: requests.Response ) -> Generator[MutableMapping[str, Any], None, None]: - """ - Decodes a requests.Response into a Mapping[str, Any] or an array + """Decodes a requests.Response into a Mapping[str, Any] or an array :param response: the response to decode :return: Generator of Mapping describing the response """ diff --git a/airbyte_cdk/sources/declarative/decoders/json_decoder.py b/airbyte_cdk/sources/declarative/decoders/json_decoder.py index 986bbd87..917cdb15 100644 --- a/airbyte_cdk/sources/declarative/decoders/json_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/json_decoder.py @@ -1,23 +1,25 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Generator, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Generator, Mapping +from typing import Any import requests -from airbyte_cdk.sources.declarative.decoders.decoder import Decoder from orjson import orjson +from airbyte_cdk.sources.declarative.decoders.decoder import Decoder + + logger = logging.getLogger("airbyte") @dataclass class JsonDecoder(Decoder): - """ - Decoder strategy that returns the json-encoded content of a response, if any. - """ + """Decoder strategy that returns the json-encoded content of a response, if any.""" parameters: InitVar[Mapping[str, Any]] @@ -25,9 +27,7 @@ def is_stream_response(self) -> bool: return False def decode(self, response: requests.Response) -> Generator[Mapping[str, Any], None, None]: - """ - Given the response is an empty string or an emtpy list, the function will return a generator with an empty mapping. - """ + """Given the response is an empty string or an emtpy list, the function will return a generator with an empty mapping.""" try: body_json = response.json() if not isinstance(body_json, list): @@ -45,9 +45,7 @@ def decode(self, response: requests.Response) -> Generator[Mapping[str, Any], No @dataclass class IterableDecoder(Decoder): - """ - Decoder strategy that returns the string content of the response, if any. - """ + """Decoder strategy that returns the string content of the response, if any.""" parameters: InitVar[Mapping[str, Any]] @@ -61,9 +59,7 @@ def decode(self, response: requests.Response) -> Generator[Mapping[str, Any], No @dataclass class JsonlDecoder(Decoder): - """ - Decoder strategy that returns the json-encoded content of the response, if any. 
- """ + """Decoder strategy that returns the json-encoded content of the response, if any.""" parameters: InitVar[Mapping[str, Any]] diff --git a/airbyte_cdk/sources/declarative/decoders/noop_decoder.py b/airbyte_cdk/sources/declarative/decoders/noop_decoder.py index eb977712..ea7e8d32 100644 --- a/airbyte_cdk/sources/declarative/decoders/noop_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/noop_decoder.py @@ -1,11 +1,15 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import logging -from typing import Any, Generator, Mapping +from collections.abc import Generator, Mapping +from typing import Any import requests + from airbyte_cdk.sources.declarative.decoders.decoder import Decoder + logger = logging.getLogger("airbyte") diff --git a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py index fa37607b..13180c7e 100644 --- a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py +++ b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py @@ -1,22 +1,24 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Generator, MutableMapping from dataclasses import dataclass -from typing import Any, Generator, MutableMapping +from typing import Any import requests + from airbyte_cdk.sources.declarative.decoders import Decoder + logger = logging.getLogger("airbyte") @dataclass class PaginationDecoderDecorator(Decoder): - """ - Decoder to wrap other decoders when instantiating a DefaultPaginator in order to bypass decoding if the response is streamed. - """ + """Decoder to wrap other decoders when instantiating a DefaultPaginator in order to bypass decoding if the response is streamed.""" def __init__(self, decoder: Decoder): self._decoder = decoder diff --git a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py index 6fb0477e..b6cc1253 100644 --- a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py @@ -1,23 +1,26 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Generator, Mapping, MutableMapping from dataclasses import InitVar, dataclass -from typing import Any, Generator, Mapping, MutableMapping +from typing import Any from xml.parsers.expat import ExpatError import requests import xmltodict + from airbyte_cdk.sources.declarative.decoders.decoder import Decoder + logger = logging.getLogger("airbyte") @dataclass class XmlDecoder(Decoder): - """ - XmlDecoder is a decoder strategy that parses the XML content of the resopnse, and converts it to a dict. + """XmlDecoder is a decoder strategy that parses the XML content of the resopnse, and converts it to a dict. This class handles XML attributes by prefixing them with an '@' symbol and represents XML text content by using the '#text' key if the element has attributes or the element name/tag. It does not currently support XML namespace declarations. diff --git a/airbyte_cdk/sources/declarative/exceptions.py b/airbyte_cdk/sources/declarative/exceptions.py index ca67c6a5..d7142fee 100644 --- a/airbyte_cdk/sources/declarative/exceptions.py +++ b/airbyte_cdk/sources/declarative/exceptions.py @@ -1,9 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations class ReadException(Exception): - """ - Raise when there is an error reading data from an API Source - """ + """Raise when there is an error reading data from an API Source""" diff --git a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py index 0878c31a..be45d56b 100644 --- a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, MutableMapping, Union +from typing import Any import dpath import requests + from airbyte_cdk.sources.declarative.decoders import Decoder, JsonDecoder from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -15,8 +18,7 @@ @dataclass class DpathExtractor(RecordExtractor): - """ - Record extractor that searches a decoded response over a path defined as an array of fields. + """Record extractor that searches a decoded response over a path defined as an array of fields. If the field path points to an array, that array is returned. If the field path points to an object, that object is returned wrapped as an array. @@ -52,7 +54,7 @@ class DpathExtractor(RecordExtractor): decoder (Decoder): The decoder responsible to transfom the response in a Mapping """ - field_path: List[Union[InterpolatedString, str]] + field_path: list[InterpolatedString | str] config: Config parameters: InitVar[Mapping[str, Any]] decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={})) diff --git a/airbyte_cdk/sources/declarative/extractors/http_selector.py b/airbyte_cdk/sources/declarative/extractors/http_selector.py index 905477a6..a165f033 100644 --- a/airbyte_cdk/sources/declarative/extractors/http_selector.py +++ b/airbyte_cdk/sources/declarative/extractors/http_selector.py @@ -1,17 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any import requests + from airbyte_cdk.sources.types import Record, StreamSlice, StreamState class HttpSelector: - """ - Responsible for translating an HTTP response into a list of records by extracting records from the response and optionally filtering + """Responsible for translating an HTTP response into a list of records by extracting records from the response and optionally filtering records based on a heuristic. 
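A hedged example of the DpathExtractor described above, pulling nested records out of a hand-built JSON response (the field names are invented for illustration):

    import requests

    from airbyte_cdk.sources.declarative.extractors.dpath_extractor import DpathExtractor

    response = requests.Response()
    response._content = b'{"data": {"records": [{"id": 1}, {"id": 2}]}}'

    extractor = DpathExtractor(field_path=["data", "records"], config={}, parameters={})
    # The path points at an array, so the array's elements are yielded one by one.
    assert list(extractor.extract_records(response)) == [{"id": 1}, {"id": 2}]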
""" @@ -21,11 +23,10 @@ def select_records( response: requests.Response, stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: - """ - Selects records from the response + """Selects records from the response :param response: The response to select the records from :param stream_state: The stream state :param records_schema: json schema of records to return diff --git a/airbyte_cdk/sources/declarative/extractors/record_extractor.py b/airbyte_cdk/sources/declarative/extractors/record_extractor.py index 5de6a84a..25c9ca4c 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/record_extractor.py @@ -1,26 +1,26 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + from abc import abstractmethod +from collections.abc import Iterable, Mapping from dataclasses import dataclass -from typing import Any, Iterable, Mapping +from typing import Any import requests @dataclass class RecordExtractor: - """ - Responsible for translating an HTTP response into a list of records by extracting records from the response. - """ + """Responsible for translating an HTTP response into a list of records by extracting records from the response.""" @abstractmethod def extract_records( self, response: requests.Response, ) -> Iterable[Mapping[str, Any]]: - """ - Selects records from the response + """Selects records from the response :param response: The response to extract the records from :return: List of Records extracted from the response """ diff --git a/airbyte_cdk/sources/declarative/extractors/record_filter.py b/airbyte_cdk/sources/declarative/extractors/record_filter.py index e84e229f..450ada8c 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_filter.py +++ b/airbyte_cdk/sources/declarative/extractors/record_filter.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import datetime +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.incremental import ( DatetimeBasedCursor, @@ -16,8 +19,7 @@ @dataclass class RecordFilter: - """ - Filter applied on a list of Records + """Filter applied on a list of Records config (Config): The user-provided configuration as specified by the source's spec condition (str): The string representing the predicate to filter a record. Records will be removed if evaluated to False @@ -36,8 +38,8 @@ def filter_records( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: kwargs = { "stream_state": stream_state, @@ -51,8 +53,7 @@ def filter_records( class ClientSideIncrementalRecordFilterDecorator(RecordFilter): - """ - Applies a filter to a list of records to exclude those that are older than the stream_state/start_date. + """Applies a filter to a list of records to exclude those that are older than the stream_state/start_date. 
:param DatetimeBasedCursor date_time_based_cursor: Cursor used to extract datetime values :param PerPartitionCursor per_partition_cursor: Optional Cursor used for mapping cursor value in nested stream_state @@ -61,7 +62,7 @@ class ClientSideIncrementalRecordFilterDecorator(RecordFilter): def __init__( self, date_time_based_cursor: DatetimeBasedCursor, - substream_cursor: Optional[Union[PerPartitionWithGlobalCursor, GlobalSubstreamCursor]], + substream_cursor: PerPartitionWithGlobalCursor | GlobalSubstreamCursor | None, **kwargs: Any, ): super().__init__(**kwargs) @@ -86,8 +87,8 @@ def filter_records( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: state_value = self._get_state_value( stream_state, stream_slice or StreamSlice(partition={}, cursor_slice={}) @@ -109,11 +110,8 @@ def filter_records( ) yield from records - def _get_state_value( - self, stream_state: StreamState, stream_slice: StreamSlice - ) -> Optional[str]: - """ - Return cursor_value or None in case it was not found. + def _get_state_value(self, stream_state: StreamState, stream_slice: StreamSlice) -> str | None: + """Return cursor_value or None in case it was not found. Cursor_value may be empty if: 1. It is an initial sync => no stream_state exist at all. 2. In Parent-child stream, and we already make initial sync, so stream_state is present. @@ -127,9 +125,8 @@ def _get_state_value( return state.get(self._cursor_field) if state else None - def _get_filter_date(self, state_value: Optional[str]) -> datetime.datetime: + def _get_filter_date(self, state_value: str | None) -> datetime.datetime: start_date_parsed = self._start_date_from_config if state_value: return max(start_date_parsed, self._date_time_based_cursor.parse_date(state_value)) - else: - return start_date_parsed + return start_date_parsed diff --git a/airbyte_cdk/sources/declarative/extractors/record_selector.py b/airbyte_cdk/sources/declarative/extractors/record_selector.py index caaa4be2..9b88b6e9 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_selector.py +++ b/airbyte_cdk/sources/declarative/extractors/record_selector.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
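As a sketch of the plain RecordFilter, whose `condition` predicate the decorator builds on (the expression and field names below are assumptions for illustration):

    from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter

    record_filter = RecordFilter(
        config={"start_date": "2023-01-01"},
        parameters={},
        condition="{{ record['updated_at'] >= config['start_date'] }}",
    )

    records = [
        {"id": 1, "updated_at": "2022-12-31"},
        {"id": 2, "updated_at": "2023-02-01"},
    ]
    # Only the record updated on or after the configured start_date should survive.
    print(list(record_filter.filter_records(records, stream_state={})))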
# +from __future__ import annotations +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any import requests + from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter @@ -14,6 +17,7 @@ from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer + SCHEMA_TRANSFORMER_TYPE_MAPPING = { SchemaNormalization.None_: TransformConfig.NoTransform, SchemaNormalization.Default: TransformConfig.DefaultSchemaNormalization, @@ -22,8 +26,7 @@ @dataclass class RecordSelector(HttpSelector): - """ - Responsible for translating an HTTP response into a list of records by extracting records from the response and optionally filtering + """Responsible for translating an HTTP response into a list of records by extracting records from the response and optionally filtering records based on a heuristic. Attributes: @@ -37,8 +40,8 @@ class RecordSelector(HttpSelector): config: Config parameters: InitVar[Mapping[str, Any]] schema_normalization: TypeTransformer - record_filter: Optional[RecordFilter] = None - transformations: List[RecordTransformation] = field(default_factory=lambda: []) + record_filter: RecordFilter | None = None + transformations: list[RecordTransformation] = field(default_factory=list) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._parameters = parameters @@ -48,11 +51,10 @@ def select_records( response: requests.Response, stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: - """ - Selects records from the response + """Selects records from the response :param response: The response to select the records from :param stream_state: The stream state :param records_schema: json schema of records to return @@ -70,11 +72,10 @@ def filter_and_transform( all_data: Iterable[Mapping[str, Any]], stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: - """ - There is an issue with the selector as of 2024-08-30: it does technology-agnostic processing like filtering, transformation and + """There is an issue with the selector as of 2024-08-30: it does technology-agnostic processing like filtering, transformation and normalization with an API that is technology-specific (as requests.Response is only for HTTP communication using the requests library). 
@@ -88,7 +89,7 @@ def filter_and_transform( yield Record(data, stream_slice) def _normalize_by_schema( - self, records: Iterable[Mapping[str, Any]], schema: Optional[Mapping[str, Any]] + self, records: Iterable[Mapping[str, Any]], schema: Mapping[str, Any] | None ) -> Iterable[Mapping[str, Any]]: if schema: # record has type Mapping[str, Any], but dict[str, Any] expected @@ -103,8 +104,8 @@ def _filter( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, ) -> Iterable[Mapping[str, Any]]: if self.record_filter: yield from self.record_filter.filter_records( @@ -120,7 +121,7 @@ def _transform( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[Mapping[str, Any]]: for record in records: for transformation in self.transformations: diff --git a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py index 8be2f6b6..fb12a53d 100644 --- a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py @@ -1,26 +1,30 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import logging import os import uuid import zlib +from collections.abc import Iterable, Mapping from contextlib import closing -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple +from typing import Any import pandas as pd import requests -from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor from numpy import nan +from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor + + EMPTY_STR: str = "" DEFAULT_ENCODING: str = "utf-8" DOWNLOAD_CHUNK_SIZE: int = 1024 * 10 class ResponseToFileExtractor(RecordExtractor): - """ - This class is used when having very big HTTP responses (usually streamed) which would require too much memory so we use disk space as + """This class is used when having very big HTTP responses (usually streamed) which would require too much memory so we use disk space as a tradeoff. Eventually, we want to support multiple file type by re-using the file based CDK parsers if possible. However, the lift is too high for @@ -30,17 +34,16 @@ class ResponseToFileExtractor(RecordExtractor): def __init__(self) -> None: self.logger = logging.getLogger("airbyte") - def _get_response_encoding(self, headers: Dict[str, Any]) -> str: - """ - Get the encoding of the response based on the provided headers. This method is heavily inspired by the requests library + def _get_response_encoding(self, headers: dict[str, Any]) -> str: + """Get the encoding of the response based on the provided headers. This method is heavily inspired by the requests library implementation. Args: headers (Dict[str, Any]): The headers of the response. + Returns: str: The encoding of the response. """ - content_type = headers.get("content-type") if not content_type: @@ -54,18 +57,17 @@ def _get_response_encoding(self, headers: Dict[str, Any]) -> str: return DEFAULT_ENCODING def _filter_null_bytes(self, b: bytes) -> bytes: - """ - Filter out null bytes from a bytes object. + """Filter out null bytes from a bytes object. 
Args: b (bytes): The input bytes object. + Returns: bytes: The filtered bytes object with null bytes removed. Referenced Issue: https://github.com/airbytehq/airbyte/issues/8300 """ - res = b.replace(b"\x00", b"") if len(res) < len(b): self.logger.warning( @@ -73,9 +75,8 @@ def _filter_null_bytes(self, b: bytes) -> bytes: ) return res - def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: - """ - Saves the binary data from the given response to a temporary file and returns the filepath and response encoding. + def _save_to_file(self, response: requests.Response) -> tuple[str, str]: + """Saves the binary data from the given response to a temporary file and returns the filepath and response encoding. Args: response (Optional[requests.Response]): The response object containing the binary data. Defaults to None. @@ -107,16 +108,14 @@ def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: # check the file exists if os.path.isfile(tmp_file): return tmp_file, response_encoding - else: - raise ValueError( - f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist." - ) + raise ValueError( + f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist." + ) def _read_with_chunks( self, path: str, file_encoding: str, chunk_size: int = 100 ) -> Iterable[Mapping[str, Any]]: - """ - Reads data from a file in chunks and yields each row as a dictionary. + """Reads data from a file in chunks and yields each row as a dictionary. Args: path (str): The path to the file to be read. @@ -129,9 +128,8 @@ def _read_with_chunks( Raises: ValueError: If an IO/Error occurs while reading the temporary data. """ - try: - with open(path, "r", encoding=file_encoding) as data: + with open(path, encoding=file_encoding) as data: chunks = pd.read_csv( data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object ) @@ -142,17 +140,16 @@ def _read_with_chunks( except pd.errors.EmptyDataError as e: self.logger.info(f"Empty data received. {e}") yield from [] - except IOError as ioe: + except OSError as ioe: raise ValueError(f"The IO/Error occured while reading tmp data. Called: {path}", ioe) finally: # remove binary tmp file, after data is read os.remove(path) def extract_records( - self, response: Optional[requests.Response] = None + self, response: requests.Response | None = None ) -> Iterable[Mapping[str, Any]]: - """ - Extracts records from the given response by: + """Extracts records from the given response by: 1) Saving the result to a tmp file 2) Reading from saved file by chunks to avoid OOM diff --git a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py index 3977623d..f7e69095 100644 --- a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py @@ -1,11 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
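A rough usage sketch for ResponseToFileExtractor with an uncompressed CSV body; setting `raw` to a BytesIO is a test convenience so `iter_content` has something to stream, and the column names are invented:

    import io

    import requests

    from airbyte_cdk.sources.declarative.extractors.response_to_file_extractor import (
        ResponseToFileExtractor,
    )

    response = requests.Response()
    response.status_code = 200
    response.raw = io.BytesIO(b"id,name\n1,alpha\n2,beta\n")

    extractor = ResponseToFileExtractor()
    # The body is spooled to a temporary file, then read back and yielded in CSV chunks.
    for record in extractor.extract_records(response):
        print(record)  # e.g. {'id': '1', 'name': 'alpha'}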
# +from __future__ import annotations import datetime +from collections.abc import Callable, Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field from datetime import timedelta -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any + +from isodate import Duration, duration_isoformat, parse_duration from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser @@ -19,13 +23,11 @@ ) from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState -from isodate import Duration, duration_isoformat, parse_duration @dataclass class DatetimeBasedCursor(DeclarativeCursor): - """ - Slices the stream over a datetime range and create a state with format {: } + """Slices the stream over a datetime range and create a state with format {: } Given a start time, end time, a step function, and an optional lookback window, the stream slicer will partition the date range from start time - lookback window to end time. @@ -51,28 +53,28 @@ class DatetimeBasedCursor(DeclarativeCursor): lookback_window (Optional[InterpolatedString]): how many days before start_datetime to read data for (ISO8601 duration) """ - start_datetime: Union[MinMaxDatetime, str] - cursor_field: Union[InterpolatedString, str] + start_datetime: MinMaxDatetime | str + cursor_field: InterpolatedString | str datetime_format: str config: Config parameters: InitVar[Mapping[str, Any]] - _highest_observed_cursor_field_value: Optional[str] = field( + _highest_observed_cursor_field_value: str | None = field( repr=False, default=None ) # tracks the latest observed datetime, which may not be safe to emit in the case of out-of-order records - _cursor: Optional[str] = field( + _cursor: str | None = field( repr=False, default=None ) # tracks the latest observed datetime that is appropriate to emit as stream state - end_datetime: Optional[Union[MinMaxDatetime, str]] = None - step: Optional[Union[InterpolatedString, str]] = None - cursor_granularity: Optional[str] = None - start_time_option: Optional[RequestOption] = None - end_time_option: Optional[RequestOption] = None - partition_field_start: Optional[str] = None - partition_field_end: Optional[str] = None - lookback_window: Optional[Union[InterpolatedString, str]] = None - message_repository: Optional[MessageRepository] = None - is_compare_strictly: Optional[bool] = False - cursor_datetime_formats: List[str] = field(default_factory=lambda: []) + end_datetime: MinMaxDatetime | str | None = None + step: InterpolatedString | str | None = None + cursor_granularity: str | None = None + start_time_option: RequestOption | None = None + end_time_option: RequestOption | None = None + partition_field_start: str | None = None + partition_field_end: str | None = None + lookback_window: InterpolatedString | str | None = None + message_repository: MessageRepository | None = None + is_compare_strictly: bool | None = False + cursor_datetime_formats: list[str] = field(default_factory=list) def __post_init__(self, parameters: Mapping[str, Any]) -> None: if (self.step and not self.cursor_granularity) or ( @@ -125,8 +127,7 @@ def get_stream_state(self) -> StreamState: return {self.cursor_field.eval(self.config): self._cursor} if self._cursor else {} # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ def set_initial_state(self, 
stream_state: StreamState) -> None: - """ - Cursors are not initialized with their state. As state is needed in order to function properly, this method should be called + """Cursors are not initialized with their state. As state is needed in order to function properly, this method should be called before calling anything else :param stream_state: The state of the stream as returned by get_stream_state @@ -136,8 +137,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: ) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ def observe(self, stream_slice: StreamSlice, record: Record) -> None: - """ - Register a record with the cursor; the cursor instance can then use it to manage the state of the in-progress stream read. + """Register a record with the cursor; the cursor instance can then use it to manage the state of the in-progress stream read. :param stream_slice: The current slice, which may or may not contain the most recently observed record :param record: the most recently-read record, which the cursor can use to update the stream state. Outwardly-visible changes to the @@ -187,8 +187,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: ) def stream_slices(self) -> Iterable[StreamSlice]: - """ - Partition the daterange into slices of size = step. + """Partition the daterange into slices of size = step. The start of the window is the minimum datetime between start_datetime - lookback_window and the stream_state's datetime The end of the window is the minimum datetime between the start of the window and end_datetime. @@ -199,7 +198,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: start_datetime = self._calculate_earliest_possible_value(self.select_best_end_datetime()) return self._partition_daterange(start_datetime, end_datetime, self._step) - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # Datetime based cursors operate over slices made up of datetime ranges. Stream state is based on the progress # through each slice and does not belong to a specific slice. We just return stream state as it is. return self.get_stream_state() @@ -224,8 +223,7 @@ def _calculate_earliest_possible_value( return max(earliest_possible_start_datetime, cursor_datetime) def select_best_end_datetime(self) -> datetime.datetime: - """ - Returns the optimal end datetime. + """Returns the optimal end datetime. This method compares the current datetime with a pre-configured end datetime and returns the earlier of the two. If no pre-configured end datetime is set, the current datetime is returned. 
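To make the slicing behaviour concrete, a hedged sketch that partitions a fixed window into daily slices (all values are invented for illustration):

    from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import (
        DatetimeBasedCursor,
    )

    cursor = DatetimeBasedCursor(
        start_datetime="2024-01-01T00:00:00Z",
        end_datetime="2024-01-03T00:00:00Z",
        cursor_field="updated_at",
        datetime_format="%Y-%m-%dT%H:%M:%SZ",
        step="P1D",
        cursor_granularity="PT1S",
        config={},
        parameters={},
    )

    for stream_slice in cursor.stream_slices():
        # Roughly one slice per day, e.g. {'start_time': '2024-01-01T00:00:00Z', 'end_time': '2024-01-01T23:59:59Z'}
        print(stream_slice.cursor_slice)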
@@ -251,8 +249,8 @@ def _partition_daterange( self, start: datetime.datetime, end: datetime.datetime, - step: Union[datetime.timedelta, Duration], - ) -> List[StreamSlice]: + step: datetime.timedelta | Duration, + ) -> list[StreamSlice]: start_field = self._partition_field_start.eval(self.config) end_field = self._partition_field_end.eval(self.config) dates = [] @@ -280,8 +278,7 @@ def _is_within_date_range(self, start: datetime.datetime, end: datetime.datetime def _evaluate_next_start_date_safely( self, start: datetime.datetime, step: datetime.timedelta ) -> datetime.datetime: - """ - Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date + """Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code would have broken anyway. """ @@ -308,10 +305,8 @@ def parse_date(self, date: str) -> datetime.datetime: raise ValueError(f"No format in {self.cursor_datetime_formats} matching {date}") @classmethod - def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, Duration]: - """ - :return Parses an ISO 8601 durations into datetime.timedelta or Duration objects. - """ + def _parse_timedelta(cls, time_str: str | None) -> datetime.timedelta | Duration: + """:return Parses an ISO 8601 durations into datetime.timedelta or Duration objects.""" if not time_str: return datetime.timedelta(0) return parse_duration(time_str) @@ -319,36 +314,36 @@ def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.header, stream_slice) def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json, stream_slice) @@ -357,7 +352,7 @@ def request_kwargs(self) -> Mapping[str, Any]: return {} def _get_request_options( - self, option_type: 
RequestOptionType, stream_slice: Optional[StreamSlice] + self, option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: options: MutableMapping[str, Any] = {} if not stream_slice: @@ -392,8 +387,8 @@ def should_be_synced(self, record: Record) -> bool: def _is_within_daterange_boundaries( self, record: Record, - start_datetime_boundary: Union[datetime.datetime, str], - end_datetime_boundary: Union[datetime.datetime, str], + start_datetime_boundary: datetime.datetime | str, + end_datetime_boundary: datetime.datetime | str, ) -> bool: cursor_field = self.cursor_field.eval(self.config) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ record_cursor_value = record.get(cursor_field) @@ -426,14 +421,12 @@ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: second_cursor_value = second.get(cursor_field) if first_cursor_value and second_cursor_value: return self.parse_date(first_cursor_value) >= self.parse_date(second_cursor_value) - elif first_cursor_value: + if first_cursor_value: return True - else: - return False + return False def set_runtime_lookback_window(self, lookback_window_in_seconds: int) -> None: - """ - Updates the lookback window based on a given number of seconds if the new duration + """Updates the lookback window based on a given number of seconds if the new duration is greater than the currently configured lookback window. :param lookback_window_in_seconds: The lookback duration in seconds to potentially update to. diff --git a/airbyte_cdk/sources/declarative/incremental/declarative_cursor.py b/airbyte_cdk/sources/declarative/incremental/declarative_cursor.py index adb64d11..ba3ec12f 100644 --- a/airbyte_cdk/sources/declarative/incremental/declarative_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/declarative_cursor.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from abc import ABC @@ -7,7 +8,6 @@ class DeclarativeCursor(Cursor, StreamSlicer, ABC): - """ - DeclarativeCursors are components that allow for checkpointing syncs. In addition to managing the fetching and updating of + """DeclarativeCursors are components that allow for checkpointing syncs. In addition to managing the fetching and updating of state, declarative cursors also manage stream slicing and injecting slice values into outbound requests. """ diff --git a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py index b912eb9a..62590cee 100644 --- a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py @@ -1,24 +1,26 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
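And a sketch of how the configured start/end time options end up in outbound request parameters (the `since`/`until` field names are invented):

    from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import (
        DatetimeBasedCursor,
    )
    from airbyte_cdk.sources.declarative.requesters.request_option import (
        RequestOption,
        RequestOptionType,
    )

    cursor = DatetimeBasedCursor(
        start_datetime="2024-01-01T00:00:00Z",
        end_datetime="2024-01-02T00:00:00Z",
        cursor_field="updated_at",
        datetime_format="%Y-%m-%dT%H:%M:%SZ",
        step="P1D",
        cursor_granularity="PT1S",
        start_time_option=RequestOption(
            field_name="since", inject_into=RequestOptionType.request_parameter, parameters={}
        ),
        end_time_option=RequestOption(
            field_name="until", inject_into=RequestOptionType.request_parameter, parameters={}
        ),
        config={},
        parameters={},
    )

    first_slice = next(iter(cursor.stream_slices()))
    # Expected shape: {'since': '2024-01-01T00:00:00Z', 'until': '2024-01-01T23:59:59Z'}
    print(cursor.get_request_params(stream_slice=first_slice))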
# +from __future__ import annotations import threading import time -from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union +from collections.abc import Callable, Iterable, Mapping +from typing import Any, TypeVar from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import Record, StreamSlice, StreamState + T = TypeVar("T") def iterate_with_last_flag_and_state( - generator: Iterable[T], get_stream_state_func: Callable[[], Optional[Mapping[str, StreamState]]] + generator: Iterable[T], get_stream_state_func: Callable[[], Mapping[str, StreamState] | None] ) -> Iterable[tuple[T, bool, Any]]: - """ - Iterates over the given generator, yielding tuples containing the element, a flag + """Iterates over the given generator, yielding tuples containing the element, a flag indicating whether it's the last element in the generator, and the result of `get_stream_state_func` applied to the element. @@ -30,7 +32,6 @@ def iterate_with_last_flag_and_state( Returns: An iterator that yields tuples of the form (element, is_last, state). """ - iterator = iter(generator) try: @@ -48,12 +49,10 @@ def iterate_with_last_flag_and_state( class Timer: - """ - A simple timer class that measures elapsed time in seconds using a high-resolution performance counter. - """ + """A simple timer class that measures elapsed time in seconds using a high-resolution performance counter.""" def __init__(self) -> None: - self._start: Optional[int] = None + self._start: int | None = None def start(self) -> None: self._start = time.perf_counter_ns() @@ -61,13 +60,11 @@ def start(self) -> None: def finish(self) -> int: if self._start: return ((time.perf_counter_ns() - self._start) / 1e9).__ceil__() - else: - raise RuntimeError("Global substream cursor timer not started") + raise RuntimeError("Global substream cursor timer not started") class GlobalSubstreamCursor(DeclarativeCursor): - """ - The GlobalSubstreamCursor is designed to track the state of substreams using a single global cursor. + """The GlobalSubstreamCursor is designed to track the state of substreams using a single global cursor. This class is beneficial for streams with many partitions, as it allows the state to be managed globally instead of per partition, simplifying state management and reducing the size of state messages. @@ -88,17 +85,16 @@ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: Partiti 0 ) # Start with 0, indicating no slices being tracked self._all_slices_yielded = False - self._lookback_window: Optional[int] = None - self._current_partition: Optional[Mapping[str, Any]] = None + self._lookback_window: int | None = None + self._current_partition: Mapping[str, Any] | None = None self._last_slice: bool = False - self._parent_state: Optional[Mapping[str, Any]] = None + self._parent_state: Mapping[str, Any] | None = None def start_slices_generation(self) -> None: self._timer.start() def stream_slices(self) -> Iterable[StreamSlice]: - """ - Generates stream slices, ensuring the last slice is properly flagged and processed. + """Generates stream slices, ensuring the last slice is properly flagged and processed. This method creates a sequence of stream slices by iterating over partitions and cursor slices. 
It holds onto one slice in memory to set `_all_slices_yielded` to `True` before yielding the @@ -135,8 +131,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str yield from slice_generator def register_slice(self, last: bool) -> None: - """ - Tracks the processing of a stream slice. + """Tracks the processing of a stream slice. Releases the semaphore for each slice. If it's the last slice (`last=True`), sets `_all_slices_yielded` to `True` to indicate no more slices will be processed. @@ -149,8 +144,7 @@ def register_slice(self, last: bool) -> None: self._all_slices_yielded = True def set_initial_state(self, stream_state: StreamState) -> None: - """ - Set the initial state for the cursors. + """Set the initial state for the cursors. This method initializes the state for the global cursor using the provided stream state. @@ -189,8 +183,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: self._partition_router.set_initial_state(stream_state) def _inject_lookback_into_stream_cursor(self, lookback_window: int) -> None: - """ - Modifies the stream cursor's lookback window based on the duration of the previous sync. + """Modifies the stream cursor's lookback window based on the duration of the previous sync. This adjustment ensures the cursor is set to the minimal lookback window necessary for avoiding missing data. @@ -214,8 +207,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: ) def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: - """ - Close the current stream slice. + """Close the current stream slice. This method is called when a stream slice is completed. For the global parent cursor, we close the child cursor only after reading all slices. This ensures that we do not miss any child records from a later parent record @@ -244,16 +236,16 @@ def get_stream_state(self) -> StreamState: return state - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # stream_slice is ignored as cursor is global return self._stream_cursor.get_stream_state() def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_params( # type: ignore # this always returns a mapping @@ -265,15 +257,14 @@ def get_request_params( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request params") + raise ValueError("A partition needs to be provided in order to get request params") def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping @@ -285,16 +276,15 @@ def get_request_headers( stream_slice=StreamSlice(partition={}, 
cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request headers") + raise ValueError("A partition needs to be provided in order to get request headers") def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: if stream_slice: return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping stream_state=stream_state, @@ -305,15 +295,14 @@ def get_request_body_data( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body data") + raise ValueError("A partition needs to be provided in order to get request body data") def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping @@ -325,8 +314,7 @@ def get_request_body_json( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body json") + raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: return self._stream_cursor.should_be_synced(self._convert_record_to_cursor_record(record)) diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index a6449d81..fcc8287d 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from collections import OrderedDict -from typing import Any, Callable, Iterable, Mapping, Optional, Union +from collections.abc import Callable, Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter @@ -13,6 +15,7 @@ ) from airbyte_cdk.sources.types import Record, StreamSlice, StreamState + logger = logging.getLogger("airbyte") @@ -25,8 +28,7 @@ def create(self) -> DeclarativeCursor: class PerPartitionCursor(DeclarativeCursor): - """ - Manages state per partition when a stream has many partitions, to prevent data loss or duplication. + """Manages state per partition when a stream has many partitions, to prevent data loss or duplication. 
**Partition Limitation and Limit Reached Logic** @@ -69,11 +71,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition)) if not cursor: - partition_state = ( - self._state_to_migrate_from - if self._state_to_migrate_from - else self._NO_CURSOR_STATE - ) + partition_state = self._state_to_migrate_from or self._NO_CURSOR_STATE cursor = self._create_cursor(partition_state) self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor @@ -83,9 +81,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str ) def _ensure_partition_limit(self) -> None: - """ - Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped. - """ + """Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped.""" while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1: self._over_limit += 1 oldest_partition = self._cursor_per_partition.popitem(last=False)[ @@ -99,8 +95,7 @@ def limit_reached(self) -> bool: return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER def set_initial_state(self, stream_state: StreamState) -> None: - """ - Set the initial state for the cursors. + """Set the initial state for the cursors. This method initializes the state for each partition cursor using the provided stream state. If a partition state is provided in the stream state, it will update the corresponding partition cursor with this state. @@ -161,7 +156,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: ) except KeyError as exception: raise ValueError( - f"Partition {str(exception)} could not be found in current state based on the record. This is unexpected because " + f"Partition {exception!s} could not be found in current state based on the record. 
This is unexpected because " f"we should only update state for partitions that were emitted during `stream_slices`" ) @@ -183,7 +178,7 @@ def get_stream_state(self) -> StreamState: state["parent_state"] = parent_state return state - def _get_state_for_partition(self, partition: Mapping[str, Any]) -> Optional[StreamState]: + def _get_state_for_partition(self, partition: Mapping[str, Any]) -> StreamState | None: cursor = self._cursor_per_partition.get(self._to_partition_key(partition)) if cursor: return cursor.get_stream_state() @@ -200,7 +195,7 @@ def _to_partition_key(self, partition: Mapping[str, Any]) -> str: def _to_dict(self, partition_key: str) -> Mapping[str, Any]: return self._partition_serializer.to_partition(partition_key) - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: if not stream_slice: raise ValueError("A partition needs to be provided in order to extract a state") @@ -217,9 +212,9 @@ def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_params( # type: ignore # this always returns a mapping @@ -233,15 +228,14 @@ def get_request_params( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request params") + raise ValueError("A partition needs to be provided in order to get request params") def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping @@ -255,16 +249,15 @@ def get_request_headers( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request headers") + raise ValueError("A partition needs to be provided in order to get request headers") def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: if stream_slice: return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping stream_state=stream_state, @@ -277,15 +270,14 @@ def get_request_body_data( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body data") + raise ValueError("A partition needs to be 
provided in order to get request body data") def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping @@ -299,8 +291,7 @@ def get_request_body_json( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body json") + raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: return self._get_cursor(record).should_be_synced( diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py index 346810a1..7797c5af 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py @@ -1,7 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Iterable, Mapping, MutableMapping, Optional, Union +from __future__ import annotations + +from collections.abc import Iterable, Mapping, MutableMapping +from typing import Any from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor @@ -18,8 +21,7 @@ class PerPartitionWithGlobalCursor(DeclarativeCursor): - """ - Manages state for streams with multiple partitions, with an optional fallback to a global cursor when specific conditions are met. + """Manages state for streams with multiple partitions, with an optional fallback to a global cursor when specific conditions are met. This cursor handles partitioned streams by maintaining individual state per partition using `PerPartitionCursor`. If the number of partitions exceeds a defined limit, it switches to a global cursor (`GlobalSubstreamCursor`) to manage state more efficiently. @@ -76,11 +78,11 @@ def __init__( self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router) self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router) self._use_global_cursor = False - self._current_partition: Optional[Mapping[str, Any]] = None + self._current_partition: Mapping[str, Any] | None = None self._last_slice: bool = False - self._parent_state: Optional[Mapping[str, Any]] = None + self._parent_state: Mapping[str, Any] | None = None - def _get_active_cursor(self) -> Union[PerPartitionCursor, GlobalSubstreamCursor]: + def _get_active_cursor(self) -> PerPartitionCursor | GlobalSubstreamCursor: return self._global_cursor if self._use_global_cursor else self._per_partition_cursor def stream_slices(self) -> Iterable[StreamSlice]: @@ -101,9 +103,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: self._parent_state = self._partition_router.get_stream_state() def set_initial_state(self, stream_state: StreamState) -> None: - """ - Set the initial state for the cursors. 
- """ + """Set the initial state for the cursors.""" self._use_global_cursor = stream_state.get("use_global_cursor", False) self._parent_state = stream_state.get("parent_state", {}) @@ -138,15 +138,15 @@ def get_stream_state(self) -> StreamState: return final_state - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: return self._get_active_cursor().select_state(stream_slice) def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_active_cursor().get_request_params( stream_state=stream_state, @@ -157,9 +157,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_active_cursor().get_request_headers( stream_state=stream_state, @@ -170,10 +170,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._get_active_cursor().get_request_body_data( stream_state=stream_state, stream_slice=stream_slice, @@ -183,9 +183,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_active_cursor().get_request_body_json( stream_state=stream_state, diff --git a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py index a0b4665f..c437a7be 100644 --- a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py @@ -1,7 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.incremental import DeclarativeCursor from airbyte_cdk.sources.declarative.types import Record, StreamSlice, StreamState @@ -22,9 +24,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: self._cursor = stream_state def observe(self, stream_slice: StreamSlice, record: Record) -> None: - """ - Resumable full refresh manages state using a page number so it does not need to update state by observing incoming records. 
- """ + """Resumable full refresh manages state using a page number so it does not need to update state by observing incoming records.""" pass def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: @@ -36,25 +36,21 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: self._cursor = stream_slice.cursor_slice def should_be_synced(self, record: Record) -> bool: - """ - Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages + """Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages that don't have filterable bounds. We should always return them. """ return True def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: - """ - RFR record don't have ordering to be compared between one another. - """ + """RFR record don't have ordering to be compared between one another.""" return False - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # A top-level RFR cursor only manages the state of a single partition return self._cursor def stream_slices(self) -> Iterable[StreamSlice]: - """ - Resumable full refresh cursors only return a single slice and can't perform partitioning because iteration is done per-page + """Resumable full refresh cursors only return a single slice and can't perform partitioning because iteration is done per-page along an unbounded set. """ yield from [StreamSlice(cursor_slice=self._cursor, partition={})] @@ -65,44 +61,43 @@ def stream_slices(self) -> Iterable[StreamSlice]: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} @dataclass class ChildPartitionResumableFullRefreshCursor(ResumableFullRefreshCursor): - """ - The Sub-stream Resumable Cursor for Full-Refresh substreams. + """The Sub-stream Resumable Cursor for Full-Refresh substreams. Follows the parent type `ResumableFullRefreshCursor` with a small override, to provide the ability to close the substream's slice once it has finished processing. 
@@ -110,8 +105,7 @@ class ChildPartitionResumableFullRefreshCursor(ResumableFullRefreshCursor): """ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: - """ - Once the current slice has finished syncing: + """Once the current slice has finished syncing: - paginator returns None - no more slices to process diff --git a/airbyte_cdk/sources/declarative/interpolation/filters.py b/airbyte_cdk/sources/declarative/interpolation/filters.py index 52d76cab..c9ca75a8 100644 --- a/airbyte_cdk/sources/declarative/interpolation/filters.py +++ b/airbyte_cdk/sources/declarative/interpolation/filters.py @@ -1,16 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import base64 import hashlib import json import re -from typing import Any, Optional +from typing import Any -def hash(value: Any, hash_type: str = "md5", salt: Optional[str] = None) -> str: - """ - Implementation of a custom Jinja2 hash filter +def hash(value: Any, hash_type: str = "md5", salt: str | None = None) -> str: + """Implementation of a custom Jinja2 hash filter Hash type defaults to 'md5' if one is not specified. If you are using this has function for GDPR compliance, then @@ -49,14 +50,13 @@ def hash(value: Any, hash_type: str = "md5", salt: Optional[str] = None) -> str: hash_obj.update(str(salt).encode("utf-8")) computed_hash: str = hash_obj.hexdigest() else: - raise AttributeError("No hashing function named {hname}".format(hname=hash_type)) + raise AttributeError(f"No hashing function named {hash_type}") return computed_hash def base64encode(value: str) -> str: - """ - Implementation of a custom Jinja2 base64encode filter + """Implementation of a custom Jinja2 base64encode filter For example: @@ -69,13 +69,11 @@ def base64encode(value: str) -> str: :param value: value to be encoded in base64 :return: base64 encoded string """ - return base64.b64encode(value.encode("utf-8")).decode() def base64decode(value: str) -> str: - """ - Implementation of a custom Jinja2 base64decode filter + """Implementation of a custom Jinja2 base64decode filter For example: @@ -88,13 +86,11 @@ def base64decode(value: str) -> str: :param value: value to be decoded from base64 :return: base64 decoded string """ - return base64.b64decode(value.encode("utf-8")).decode() def string(value: Any) -> str: - """ - Converts the input value to a string. + """Converts the input value to a string. If the value is already a string, it is returned as is. Otherwise, the value is interpreted as a json object and wrapped in triple-quotes so it's evalued as a string by the JinjaInterpolation :param value: the value to convert to a string @@ -107,9 +103,7 @@ def string(value: Any) -> str: def regex_search(value: str, regex: str) -> str: - """ - Match a regular expression against a string and return the first match group if it exists. - """ + """Match a regular expression against a string and return the first match group if it exists.""" match = re.search(regex, value) if match and len(match.groups()) > 0: return match.group(1) diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py index 78569b35..8c3b0509 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py @@ -1,14 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Final, List, Mapping +from typing import Any, Final from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config -FALSE_VALUES: Final[List[Any]] = [ + +FALSE_VALUES: Final[list[Any]] = [ "False", "false", "{}", @@ -43,8 +46,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._parameters = parameters def eval(self, config: Config, **additional_parameters: Any) -> bool: - """ - Interpolates the predicate condition string using the config and other optional arguments passed as parameter. + """Interpolates the predicate condition string using the config and other optional arguments passed as parameter. :param config: The user-provided configuration as specified by the source's spec :param additional_parameters: Optional parameters used for interpolation @@ -52,15 +54,14 @@ def eval(self, config: Config, **additional_parameters: Any) -> bool: """ if isinstance(self.condition, bool): return self.condition - else: - evaluated = self._interpolation.eval( - self.condition, - config, - self._default, - parameters=self._parameters, - **additional_parameters, - ) - if evaluated in FALSE_VALUES: - return False - # The presence of a value is generally regarded as truthy, so we treat it as such - return True + evaluated = self._interpolation.eval( + self.condition, + config, + self._default, + parameters=self._parameters, + **additional_parameters, + ) + if evaluated in FALSE_VALUES: + return False + # The presence of a value is generally regarded as truthy, so we treat it as such + return True diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py index 11b2dac9..d011485f 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations - +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config @@ -12,8 +13,7 @@ @dataclass class InterpolatedMapping: - """ - Wrapper around a Mapping[str, str] where both the keys and values are to be interpolated. + """Wrapper around a Mapping[str, str] where both the keys and values are to be interpolated. Attributes: mapping (Mapping[str, str]): to be evaluated @@ -22,13 +22,12 @@ class InterpolatedMapping: mapping: Mapping[str, str] parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None: + def __post_init__(self, parameters: Mapping[str, Any] | None) -> None: self._interpolation = JinjaInterpolation() self._parameters = parameters - def eval(self, config: Config, **additional_parameters: Any) -> Dict[str, Any]: - """ - Wrapper around a Mapping[str, str] that allows for both keys and values to be interpolated. + def eval(self, config: Config, **additional_parameters: Any) -> dict[str, Any]: + """Wrapper around a Mapping[str, str] that allows for both keys and values to be interpolated. 
:param config: The user-provided configuration as specified by the source's spec :param additional_parameters: Optional parameters used for interpolation @@ -52,5 +51,4 @@ def _eval(self, value: str, config: Config, **kwargs: Any) -> Any: # We only want to interpolate them if they are strings if isinstance(value, str): return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs) - else: - return value + return value diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py index 82454919..d37b96e6 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py @@ -1,14 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations - +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any, Union from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config + NestedMappingEntry = Union[ dict[str, "NestedMapping"], list["NestedMapping"], str, int, float, bool, None ] @@ -17,8 +19,7 @@ @dataclass class InterpolatedNestedMapping: - """ - Wrapper around a nested dict which can contain lists and primitive values where both the keys and values are interpolated recursively. + """Wrapper around a nested dict which can contain lists and primitive values where both the keys and values are interpolated recursively. Attributes: mapping (NestedMapping): to be evaluated @@ -27,7 +28,7 @@ class InterpolatedNestedMapping: mapping: NestedMapping parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None: + def __post_init__(self, parameters: Mapping[str, Any] | None) -> None: self._interpolation = JinjaInterpolation() self._parameters = parameters @@ -35,18 +36,17 @@ def eval(self, config: Config, **additional_parameters: Any) -> Any: return self._eval(self.mapping, config, **additional_parameters) def _eval( - self, value: Union[NestedMapping, NestedMappingEntry], config: Config, **kwargs: Any + self, value: NestedMapping | NestedMappingEntry, config: Config, **kwargs: Any ) -> Any: # Recursively interpolate dictionaries and lists if isinstance(value, str): return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs) - elif isinstance(value, dict): + if isinstance(value, dict): interpolated_dict = { self._eval(k, config, **kwargs): self._eval(v, config, **kwargs) for k, v in value.items() } return {k: v for k, v in interpolated_dict.items() if v is not None} - elif isinstance(value, list): + if isinstance(value, list): return [self._eval(v, config, **kwargs) for v in value] - else: - return value + return value diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py index 542fa806..773b00cb 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config @@ -11,8 +13,7 @@ @dataclass class InterpolatedString: - """ - Wrapper around a raw string to be interpolated with the Jinja2 templating engine + """Wrapper around a raw string to be interpolated with the Jinja2 templating engine Attributes: string (str): The string to evalute @@ -22,7 +23,7 @@ class InterpolatedString: string: str parameters: InitVar[Mapping[str, Any]] - default: Optional[str] = None + default: str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.default = self.default or self.string @@ -33,8 +34,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._is_plain_string = None def eval(self, config: Config, **kwargs: Any) -> Any: - """ - Interpolates the input string using the config and other optional arguments passed as parameter. + """Interpolates the input string using the config and other optional arguments passed as parameter. :param config: The user-provided configuration as specified by the source's spec :param kwargs: Optional parameters used for interpolation @@ -54,7 +54,7 @@ def eval(self, config: Config, **kwargs: Any) -> Any: self.string, config, self.default, parameters=self._parameters, **kwargs ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, InterpolatedString): return False return self.string == other.string and self.default == other.default @@ -62,12 +62,11 @@ def __eq__(self, other: Any) -> bool: @classmethod def create( cls, - string_or_interpolated: Union["InterpolatedString", str], + string_or_interpolated: InterpolatedString | str, *, parameters: Mapping[str, Any], - ) -> "InterpolatedString": - """ - Helper function to obtain an InterpolatedString from either a raw string or an InterpolatedString. + ) -> InterpolatedString: + """Helper function to obtain an InterpolatedString from either a raw string or an InterpolatedString. :param string_or_interpolated: either a raw string or an InterpolatedString. :param parameters: parameters propagated from parent component @@ -75,5 +74,4 @@ def create( """ if isinstance(string_or_interpolated, str): return InterpolatedString(string=string_or_interpolated, parameters=parameters) - else: - return string_or_interpolated + return string_or_interpolated diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolation.py b/airbyte_cdk/sources/declarative/interpolation/interpolation.py index 5af61905..525fb0be 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolation.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolation.py @@ -1,28 +1,26 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from airbyte_cdk.sources.types import Config class Interpolation(ABC): - """ - Strategy for evaluating the interpolated value of a string at runtime using Jinja. 
- """ + """Strategy for evaluating the interpolated value of a string at runtime using Jinja.""" @abstractmethod def eval( self, input_str: str, config: Config, - default: Optional[str] = None, + default: str | None = None, **additional_options: Any, ) -> Any: - """ - Interpolates the input string using the config, and additional options passed as parameter. + """Interpolates the input string using the config, and additional options passed as parameter. :param input_str: The string to interpolate :param config: The user-provided configuration as specified by the source's spec diff --git a/airbyte_cdk/sources/declarative/interpolation/jinja.py b/airbyte_cdk/sources/declarative/interpolation/jinja.py index 553ef024..a39e0cfc 100644 --- a/airbyte_cdk/sources/declarative/interpolation/jinja.py +++ b/airbyte_cdk/sources/declarative/interpolation/jinja.py @@ -1,37 +1,38 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import ast +from collections.abc import Mapping from functools import cache -from typing import Any, Mapping, Optional, Tuple, Type +from typing import Any -from airbyte_cdk.sources.declarative.interpolation.filters import filters -from airbyte_cdk.sources.declarative.interpolation.interpolation import Interpolation -from airbyte_cdk.sources.declarative.interpolation.macros import macros -from airbyte_cdk.sources.types import Config from jinja2 import meta from jinja2.environment import Template from jinja2.exceptions import UndefinedError from jinja2.sandbox import SandboxedEnvironment +from airbyte_cdk.sources.declarative.interpolation.filters import filters +from airbyte_cdk.sources.declarative.interpolation.interpolation import Interpolation +from airbyte_cdk.sources.declarative.interpolation.macros import macros +from airbyte_cdk.sources.types import Config + class StreamPartitionAccessEnvironment(SandboxedEnvironment): - """ - Currently, source-jira is setting an attribute to StreamSlice specific to its use case which because of the PerPartitionCursor is set to + """Currently, source-jira is setting an attribute to StreamSlice specific to its use case which because of the PerPartitionCursor is set to StreamSlice._partition but not exposed through StreamSlice.partition. This is a patch to still allow source-jira to have access to this parameter """ def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool: - if attr in ["_partition"]: + if attr == "_partition": return True return super().is_safe_attribute(obj, attr, value) # type: ignore # for some reason, mypy says 'Returning Any from function declared to return "bool"' class JinjaInterpolation(Interpolation): - """ - Interpolation strategy using the Jinja2 template engine. + """Interpolation strategy using the Jinja2 template engine. If the input string is a raw string, the interpolated string will be the same. `eval("hello world") -> "hello world"` @@ -79,8 +80,8 @@ def eval( self, input_str: str, config: Config, - default: Optional[str] = None, - valid_types: Optional[Tuple[Type[Any]]] = None, + default: str | None = None, + valid_types: tuple[type[Any]] | None = None, **additional_parameters: Any, ) -> Any: context = {"config": config, **additional_parameters} @@ -91,7 +92,7 @@ def eval( raise ValueError( f"Found reserved keyword {alias} in interpolation context. This is unexpected and indicative of a bug in the CDK." 
) - elif equivalent in context: + if equivalent in context: context[alias] = context[equivalent] try: @@ -107,7 +108,7 @@ def eval( # If result is empty or resulted in an undefined error, evaluate and return the default string return self._literal_eval(self._eval(default, context), valid_types) - def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[Any]]]) -> Any: + def _literal_eval(self, result: str | None, valid_types: tuple[type[Any]] | None) -> Any: try: evaluated = ast.literal_eval(result) # type: ignore # literal_eval is able to handle None except (ValueError, SyntaxError): @@ -116,7 +117,7 @@ def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[ return evaluated return result - def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]: + def _eval(self, s: str | None, context: Mapping[str, Any]) -> str | None: try: undeclared = self._find_undeclared_variables(s) undeclared_not_in_context = {var for var in undeclared if var not in context} @@ -131,16 +132,12 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]: return s @cache - def _find_undeclared_variables(self, s: Optional[str]) -> Template: - """ - Find undeclared variables and cache them - """ + def _find_undeclared_variables(self, s: str | None) -> Template: + """Find undeclared variables and cache them""" ast = self._environment.parse(s) # type: ignore # parse is able to handle None return meta.find_undeclared_variables(ast) @cache - def _compile(self, s: Optional[str]) -> Template: - """ - We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader - """ + def _compile(self, s: str | None) -> Template: + """We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader""" return self._environment.from_string(s) diff --git a/airbyte_cdk/sources/declarative/interpolation/macros.py b/airbyte_cdk/sources/declarative/interpolation/macros.py index ce448c12..33d562cc 100644 --- a/airbyte_cdk/sources/declarative/interpolation/macros.py +++ b/airbyte_cdk/sources/declarative/interpolation/macros.py @@ -1,25 +1,25 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import builtins import datetime import typing -from typing import Optional, Union import isodate import pytz from dateutil import parser from isodate import parse_duration + """ This file contains macros that can be evaluated by a `JinjaInterpolation` object """ def now_utc() -> datetime.datetime: - """ - Current local date and time in UTC timezone + """Current local date and time in UTC timezone Usage: `"{{ now_utc() }}"` @@ -28,8 +28,7 @@ def now_utc() -> datetime.datetime: def today_utc() -> datetime.date: - """ - Current date in UTC timezone + """Current date in UTC timezone Usage: `"{{ today_utc() }}"` @@ -38,8 +37,7 @@ def today_utc() -> datetime.date: def today_with_timezone(timezone: str) -> datetime.date: - """ - Current date in custom timezone + """Current date in custom timezone :param timezone: timezone expressed as IANA keys format. 
Example: "Pacific/Tarawa" :return: @@ -47,9 +45,8 @@ def today_with_timezone(timezone: str) -> datetime.date: return datetime.datetime.now(tz=pytz.timezone(timezone)).date() -def timestamp(dt: Union[float, str]) -> Union[int, float]: - """ - Converts a number or a string to a timestamp +def timestamp(dt: float | str) -> int | float: + """Converts a number or a string to a timestamp If dt is a number, then convert to an int If dt is a string, then parse it using dateutil.parser @@ -62,8 +59,7 @@ def timestamp(dt: Union[float, str]) -> Union[int, float]: """ if isinstance(dt, (int, float)): return int(dt) - else: - return _str_to_datetime(dt).astimezone(pytz.utc).timestamp() + return _str_to_datetime(dt).astimezone(pytz.utc).timestamp() def _str_to_datetime(s: str) -> datetime.datetime: @@ -75,8 +71,7 @@ def _str_to_datetime(s: str) -> datetime.datetime: def max(*args: typing.Any) -> typing.Any: - """ - Returns biggest object of an iterable, or two or more arguments. + """Returns biggest object of an iterable, or two or more arguments. max(iterable, *[, default=obj, key=func]) -> value max(arg1, arg2, *args, *[, key=func]) -> value @@ -95,8 +90,7 @@ def max(*args: typing.Any) -> typing.Any: def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: - """ - Returns datetime of now() + num_days + """Returns datetime of now() + num_days Usage: `"{{ day_delta(25) }}"` @@ -109,9 +103,8 @@ def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: ).strftime(format) -def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]: - """ - Converts ISO8601 duration to datetime.timedelta +def duration(datestring: str) -> datetime.timedelta | isodate.Duration: + """Converts ISO8601 duration to datetime.timedelta Usage: `"{{ now_utc() - duration('P1D') }}"` @@ -120,10 +113,9 @@ def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]: def format_datetime( - dt: Union[str, datetime.datetime], format: str, input_format: Optional[str] = None + dt: str | datetime.datetime, format: str, input_format: str | None = None ) -> str: - """ - Converts datetime to another format + """Converts datetime to another format Usage: `"{{ format_datetime(config.start_date, '%Y-%m-%d') }}"` diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py index 05fbee7a..a24b4b06 100644 --- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py +++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py @@ -1,16 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json import logging import pkgutil import re +from collections.abc import Iterator, Mapping from copy import deepcopy from importlib import metadata -from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union +from typing import Any import yaml +from jsonschema.exceptions import ValidationError +from jsonschema.validators import validate + from airbyte_cdk.models import ( AirbyteConnectionStatus, AirbyteMessage, @@ -44,8 +49,6 @@ DebugSliceLogger, SliceLogger, ) -from jsonschema.exceptions import ValidationError -from jsonschema.validators import validate class ManifestDeclarativeSource(DeclarativeSource): @@ -56,10 +59,9 @@ def __init__( source_config: ConnectionDefinition, debug: bool = False, emit_connector_builder_messages: bool = False, - component_factory: Optional[ModelToComponentFactory] = None, + component_factory: ModelToComponentFactory | None = None, ): - """ - :param source_config(Mapping[str, Any]): The manifest of low-code components that describe the source connector + """:param source_config(Mapping[str, Any]): The manifest of low-code components that describe the source connector :param debug(bool): True if debug mode is enabled :param component_factory(ModelToComponentFactory): optional factory if ModelToComponentFactory's default behaviour needs to be tweaked """ @@ -77,10 +79,8 @@ def __init__( self._source_config = propagated_source_config self._debug = debug self._emit_connector_builder_messages = emit_connector_builder_messages - self._constructor = ( - component_factory - if component_factory - else ModelToComponentFactory(emit_connector_builder_messages) + self._constructor = component_factory or ModelToComponentFactory( + emit_connector_builder_messages ) self._message_repository = self._constructor.get_message_repository() self._slice_logger: SliceLogger = ( @@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]: return self._source_config @property - def message_repository(self) -> Union[None, MessageRepository]: + def message_repository(self) -> None | MessageRepository: return self._message_repository @property @@ -110,12 +110,11 @@ def connection_checker(self) -> ConnectionChecker: ) if isinstance(check_stream, ConnectionChecker): return check_stream - else: - raise ValueError( - f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}" - ) + raise ValueError( + f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}" + ) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: self._emit_manifest_debug_message( extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)} ) @@ -135,8 +134,8 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: @staticmethod def _initialize_cache_for_parent_streams( - stream_configs: List[Dict[str, Any]], - ) -> List[Dict[str, Any]]: + stream_configs: list[dict[str, Any]], + ) -> list[dict[str, Any]]: parent_streams = set() def update_with_cache_parent_configs(parent_configs: list[dict[str, Any]]) -> None: @@ -170,8 +169,7 @@ def update_with_cache_parent_configs(parent_configs: list[dict[str, Any]]) -> No return stream_configs def spec(self, logger: logging.Logger) -> ConnectorSpecification: - """ - Returns the connector specification (spec) as defined in the Airbyte Protocol. 
The spec is an object describing the possible + """Returns the connector specification (spec) as defined in the Airbyte Protocol. The spec is an object describing the possible configurations (e.g: username and password) which can be configured when running this connector. For low-code connectors, this will first attempt to load the spec from the manifest's spec block, otherwise it will load it from "spec.yaml" or "spec.json" in the project root. @@ -187,8 +185,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: spec["type"] = "Spec" spec_component = self._constructor.create_component(SpecModel, spec, dict()) return spec_component.generate_spec() - else: - return super().spec(logger) + return super().spec(logger) def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus: self._configure_logger_level(logger) @@ -199,22 +196,18 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: self._configure_logger_level(logger) yield from super().read(logger, config, catalog, state) def _configure_logger_level(self, logger: logging.Logger) -> None: - """ - Set the log level to logging.DEBUG if debug mode is enabled - """ + """Set the log level to logging.DEBUG if debug mode is enabled""" if self._debug: logger.setLevel(logging.DEBUG) def _validate_source(self) -> None: - """ - Validates the connector manifest against the declarative component schema - """ + """Validates the connector manifest against the declarative component schema""" try: raw_component_schema = pkgutil.get_data( "airbyte_cdk", "sources/declarative/declarative_component_schema.yaml" @@ -263,7 +256,7 @@ def _validate_source(self) -> None: f"The manifest version {manifest_version} is greater than the airbyte-cdk package version ({cdk_version}). Your " f"manifest may contain features that are not in the current CDK version." ) - elif manifest_major == 0 and manifest_minor < 29: + if manifest_major == 0 and manifest_minor < 29: raise ValidationError( f"The low-code framework was promoted to Beta in airbyte-cdk version 0.29.0 and contains many breaking changes to the " f"language. The manifest version {manifest_version} is incompatible with the airbyte-cdk package version " @@ -271,10 +264,8 @@ def _validate_source(self) -> None: ) @staticmethod - def _get_version_parts(version: str, version_type: str) -> Tuple[int, int, int]: - """ - Takes a semantic version represented as a string and splits it into a tuple of its major, minor, and patch versions. 
- """ + def _get_version_parts(version: str, version_type: str) -> tuple[int, int, int]: + """Takes a semantic version represented as a string and splits it into a tuple of its major, minor, and patch versions.""" version_parts = re.split(r"\.", version) if len(version_parts) != 3 or not all([part.isdigit() for part in version_parts]): raise ValidationError( @@ -282,9 +273,9 @@ def _get_version_parts(version: str, version_type: str) -> Tuple[int, int, int]: ) return tuple(int(part) for part in version_parts) # type: ignore # We already verified there were 3 parts and they are all digits - def _stream_configs(self, manifest: Mapping[str, Any]) -> List[Dict[str, Any]]: + def _stream_configs(self, manifest: Mapping[str, Any]) -> list[dict[str, Any]]: # This has a warning flag for static, but after we finish part 4 we'll replace manifest with self._source_config - stream_configs: List[Dict[str, Any]] = manifest.get("streams", []) + stream_configs: list[dict[str, Any]] = manifest.get("streams", []) for s in stream_configs: if "type" not in s: s["type"] = "DeclarativeStream" diff --git a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py index 38546168..e34a11bc 100644 --- a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py +++ b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration @@ -13,8 +15,7 @@ def _is_already_migrated(stream_state: Mapping[str, Any]) -> bool: class LegacyToPerPartitionStateMigration(StateMigration): - """ - Transforms the input state for per-partitioned streams from the legacy format to the low-code format. + """Transforms the input state for per-partitioned streams from the legacy format to the low-code format. The cursor field and partition ID fields are automatically extracted from the stream's DatetimebasedCursor and SubstreamPartitionRouter. Example input state: diff --git a/airbyte_cdk/sources/declarative/migrations/state_migration.py b/airbyte_cdk/sources/declarative/migrations/state_migration.py index 9cf7f3cf..6146073a 100644 --- a/airbyte_cdk/sources/declarative/migrations/state_migration.py +++ b/airbyte_cdk/sources/declarative/migrations/state_migration.py @@ -1,14 +1,15 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from abc import abstractmethod -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any class StateMigration: @abstractmethod def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: - """ - Check if the stream_state should be migrated + """Check if the stream_state should be migrated :param stream_state: The stream_state to potentially migrate :return: true if the state is of the expected format and should be migrated. False otherwise. @@ -16,8 +17,7 @@ def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: @abstractmethod def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: - """ - Migrate the stream_state. 
Assumes should_migrate(stream_state) returned True. + """Migrate the stream_state. Assumes should_migrate(stream_state) returned True. :param stream_state: The stream_state to migrate :return: The migrated stream_state diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 43848eae..a98a275e 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -4,10 +4,9 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Literal from pydantic.v1 import BaseModel, Extra, Field -from typing_extensions import Literal class AuthFlowType(Enum): @@ -23,13 +22,13 @@ class BasicHttpAuthenticator(BaseModel): examples=["{{ config['username'] }}", "{{ config['api_key'] }}"], title="Username", ) - password: Optional[str] = Field( + password: str | None = Field( "", description="The password that will be combined with the username, base64 encoded and used to make requests. Fill it in the user inputs.", examples=["{{ config['password'] }}", ""], title="Password", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class BearerAuthenticator(BaseModel): @@ -40,12 +39,12 @@ class BearerAuthenticator(BaseModel): examples=["{{ config['api_key'] }}", "{{ config['token'] }}"], title="Bearer Token", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CheckStream(BaseModel): type: Literal["CheckStream"] - stream_names: List[str] = Field( + stream_names: list[str] = Field( ..., description="Names of the streams to try reading from when running a check operation.", examples=[["users"], ["users", "contacts"]], @@ -54,31 +53,31 @@ class CheckStream(BaseModel): class ConcurrencyLevel(BaseModel): - type: Optional[Literal["ConcurrencyLevel"]] = None - default_concurrency: Union[int, str] = Field( + type: Literal["ConcurrencyLevel"] | None = None + default_concurrency: int | str = Field( ..., description="The amount of concurrency that will applied during a sync. This value can be hardcoded or user-defined in the config if different users have varying volume thresholds in the target API.", examples=[10, "{{ config['num_workers'] or 10 }}"], title="Default Concurrency", ) - max_concurrency: Optional[int] = Field( + max_concurrency: int | None = Field( None, description="The maximum level of concurrency that will be used during a sync. 
This becomes a required field when the default_concurrency derives from the config, because it serves as a safeguard against a user-defined threshold that is too high.", examples=[20, 100], title="Max Concurrency", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ConstantBackoffStrategy(BaseModel): type: Literal["ConstantBackoffStrategy"] - backoff_time_in_seconds: Union[float, str] = Field( + backoff_time_in_seconds: float | str = Field( ..., description="Backoff time in seconds.", examples=[30, 30.5, "{{ config['backoff_time'] }}"], title="Backoff Time", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CursorPagination(BaseModel): @@ -93,13 +92,13 @@ class CursorPagination(BaseModel): ], title="Cursor Value", ) - page_size: Optional[int] = Field( + page_size: int | None = Field( None, description="The number of records to include in each pages.", examples=[100], title="Page Size", ) - stop_condition: Optional[str] = Field( + stop_condition: str | None = Field( None, description="Template string evaluating when to stop paginating.", examples=[ @@ -108,7 +107,7 @@ class CursorPagination(BaseModel): ], title="Stop Condition", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomAuthenticator(BaseModel): @@ -122,7 +121,7 @@ class Config: examples=["source_railz.components.ShortLivedTokenAuthenticator"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomBackoffStrategy(BaseModel): @@ -136,7 +135,7 @@ class Config: examples=["source_railz.components.MyCustomBackoffStrategy"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomErrorHandler(BaseModel): @@ -150,7 +149,7 @@ class Config: examples=["source_railz.components.MyCustomErrorHandler"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomIncrementalSync(BaseModel): @@ -168,7 +167,7 @@ class Config: ..., description="The location of the value on a record that will be used as a bookmark during sync.", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomPaginationStrategy(BaseModel): @@ -182,7 +181,7 @@ class Config: examples=["source_railz.components.MyCustomPaginationStrategy"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRecordExtractor(BaseModel): @@ -196,7 +195,7 @@ class Config: examples=["source_railz.components.MyCustomRecordExtractor"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRecordFilter(BaseModel): @@ -210,7 +209,7 @@ class Config: examples=["source_railz.components.MyCustomCustomRecordFilter"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, 
alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRequester(BaseModel): @@ -224,7 +223,7 @@ class Config: examples=["source_railz.components.MyCustomRecordExtractor"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRetriever(BaseModel): @@ -238,7 +237,7 @@ class Config: examples=["source_railz.components.MyCustomRetriever"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomPartitionRouter(BaseModel): @@ -252,7 +251,7 @@ class Config: examples=["source_railz.components.MyCustomPartitionRouter"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomSchemaLoader(BaseModel): @@ -266,7 +265,7 @@ class Config: examples=["source_railz.components.MyCustomSchemaLoader"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomStateMigration(BaseModel): @@ -280,7 +279,7 @@ class Config: examples=["source_railz.components.MyCustomStateMigration"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomTransformation(BaseModel): @@ -294,14 +293,14 @@ class Config: examples=["source_railz.components.MyCustomTransformation"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class LegacyToPerPartitionStateMigration(BaseModel): class Config: extra = Extra.allow - type: Optional[Literal["LegacyToPerPartitionStateMigration"]] = None + type: Literal["LegacyToPerPartitionStateMigration"] | None = None class Algorithm(Enum): @@ -325,19 +324,19 @@ class JwtHeaders(BaseModel): class Config: extra = Extra.forbid - kid: Optional[str] = Field( + kid: str | None = Field( None, description="Private key ID for user account.", examples=["{{ config['kid'] }}"], title="Key Identifier", ) - typ: Optional[str] = Field( + typ: str | None = Field( "JWT", description="The media type of the complete JWT.", examples=["JWT"], title="Type", ) - cty: Optional[str] = Field( + cty: str | None = Field( None, description="Content type of JWT header.", examples=["JWT"], @@ -349,18 +348,18 @@ class JwtPayload(BaseModel): class Config: extra = Extra.forbid - iss: Optional[str] = Field( + iss: str | None = Field( None, description="The user/principal that issued the JWT. Commonly a value unique to the user.", examples=["{{ config['iss'] }}"], title="Issuer", ) - sub: Optional[str] = Field( + sub: str | None = Field( None, description="The subject of the JWT. Commonly defined by the API.", title="Subject", ) - aud: Optional[str] = Field( + aud: str | None = Field( None, description="The recipient that the JWT is intended for. 
Commonly defined by the API.", examples=["appstoreconnect-v1"], @@ -375,7 +374,7 @@ class JwtAuthenticator(BaseModel): description="Secret used to sign the JSON web token.", examples=["{{ config['secret_key'] }}"], ) - base64_encode_secret_key: Optional[bool] = Field( + base64_encode_secret_key: bool | None = Field( False, description='When set to true, the secret key will be base64 encoded prior to being encoded as part of the JWT. Only set to "true" when required by the API.', ) @@ -384,79 +383,79 @@ class JwtAuthenticator(BaseModel): description="Algorithm used to sign the JSON web token.", examples=["ES256", "HS256", "RS256", "{{ config['algorithm'] }}"], ) - token_duration: Optional[int] = Field( + token_duration: int | None = Field( 1200, description="The amount of time in seconds a JWT token can be valid after being issued.", examples=[1200, 3600], title="Token Duration", ) - header_prefix: Optional[str] = Field( + header_prefix: str | None = Field( None, description="The prefix to be used within the Authentication header.", examples=["Bearer", "Basic"], title="Header Prefix", ) - jwt_headers: Optional[JwtHeaders] = Field( + jwt_headers: JwtHeaders | None = Field( None, description="JWT headers used when signing JSON web token.", title="JWT Headers", ) - additional_jwt_headers: Optional[Dict[str, Any]] = Field( + additional_jwt_headers: dict[str, Any] | None = Field( None, description="Additional headers to be included with the JWT headers object.", title="Additional JWT Headers", ) - jwt_payload: Optional[JwtPayload] = Field( + jwt_payload: JwtPayload | None = Field( None, description="JWT Payload used when signing JSON web token.", title="JWT Payload", ) - additional_jwt_payload: Optional[Dict[str, Any]] = Field( + additional_jwt_payload: dict[str, Any] | None = Field( None, description="Additional properties to be added to the JWT payload.", title="Additional JWT Payload Properties", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class RefreshTokenUpdater(BaseModel): - refresh_token_name: Optional[str] = Field( + refresh_token_name: str | None = Field( "refresh_token", description="The name of the property which contains the updated refresh token in the response from the token refresh endpoint.", examples=["refresh_token"], title="Refresh Token Property Name", ) - access_token_config_path: Optional[List[str]] = Field( + access_token_config_path: list[str] | None = Field( ["credentials", "access_token"], description="Config path to the access token. Make sure the field actually exists in the config.", examples=[["credentials", "access_token"], ["access_token"]], title="Config Path To Access Token", ) - refresh_token_config_path: Optional[List[str]] = Field( + refresh_token_config_path: list[str] | None = Field( ["credentials", "refresh_token"], description="Config path to the access token. Make sure the field actually exists in the config.", examples=[["credentials", "refresh_token"], ["refresh_token"]], title="Config Path To Refresh Token", ) - token_expiry_date_config_path: Optional[List[str]] = Field( + token_expiry_date_config_path: list[str] | None = Field( ["credentials", "token_expiry_date"], description="Config path to the expiry date. 
Make sure actually exists in the config.", examples=[["credentials", "token_expiry_date"]], title="Config Path To Expiry Date", ) - refresh_token_error_status_codes: Optional[List[int]] = Field( + refresh_token_error_status_codes: list[int] | None = Field( [], description="Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). Responses with one of the error status code and containing an error value will be flagged as a config error", examples=[[400, 500]], title="Refresh Token Error Status Codes", ) - refresh_token_error_key: Optional[str] = Field( + refresh_token_error_key: str | None = Field( "", description="Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).", examples=["error"], title="Refresh Token Error Key", ) - refresh_token_error_values: Optional[List[str]] = Field( + refresh_token_error_values: list[str] | None = Field( [], description='List of values to check for exception during token refresh process. Used to check if the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status code and containing an error value will be flagged as a config error', examples=[["invalid_grant", "invalid_permissions"]], @@ -481,7 +480,7 @@ class OAuthAuthenticator(BaseModel): ], title="Client Secret", ) - refresh_token: Optional[str] = Field( + refresh_token: str | None = Field( None, description="Credential artifact used to get a new access token.", examples=[ @@ -496,25 +495,25 @@ class OAuthAuthenticator(BaseModel): examples=["https://connect.squareup.com/oauth2/token"], title="Token Refresh Endpoint", ) - access_token_name: Optional[str] = Field( + access_token_name: str | None = Field( "access_token", description="The name of the property which contains the access token in the response from the token refresh endpoint.", examples=["access_token"], title="Access Token Property Name", ) - expires_in_name: Optional[str] = Field( + expires_in_name: str | None = Field( "expires_in", description="The name of the property which contains the expiry date in the response from the token refresh endpoint.", examples=["expires_in"], title="Token Expiry Property Name", ) - grant_type: Optional[str] = Field( + grant_type: str | None = Field( "refresh_token", description="Specifies the OAuth2 grant type. If set to refresh_token, the refresh_token needs to be provided as well. For client_credentials, only client id and secret are required. 
Other grant types are not officially supported.", examples=["refresh_token", "client_credentials"], title="Grant Type", ) - refresh_request_body: Optional[Dict[str, Any]] = Field( + refresh_request_body: dict[str, Any] | None = Field( None, description="Body of the request sent to get a new access token.", examples=[ @@ -526,35 +525,35 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Request Body", ) - scopes: Optional[List[str]] = Field( + scopes: list[str] | None = Field( None, description="List of scopes that should be granted to the access token.", examples=[["crm.list.read", "crm.objects.contacts.read", "crm.schema.contacts.read"]], title="Scopes", ) - token_expiry_date: Optional[str] = Field( + token_expiry_date: str | None = Field( None, description="The access token expiry date.", examples=["2023-04-06T07:12:10.421833+00:00", 1680842386], title="Token Expiry Date", ) - token_expiry_date_format: Optional[str] = Field( + token_expiry_date_format: str | None = Field( None, description="The format of the time to expiration datetime. Provide it if the time is returned as a date-time string instead of seconds.", examples=["%Y-%m-%d %H:%M:%S.%f+00:00"], title="Token Expiry Date Format", ) - refresh_token_updater: Optional[RefreshTokenUpdater] = Field( + refresh_token_updater: RefreshTokenUpdater | None = Field( None, description="When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.", title="Token Updater", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DpathExtractor(BaseModel): type: Literal["DpathExtractor"] - field_path: List[str] = Field( + field_path: list[str] = Field( ..., description='List of potentially nested fields describing the full path of the field to extract. Use "*" to extract all values from an array. 
See more info in the [docs](https://docs.airbyte.com/connector-development/config-based/understanding-the-yaml-file/record-selector).', examples=[ @@ -565,18 +564,18 @@ class DpathExtractor(BaseModel): ], title="Field Path", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ExponentialBackoffStrategy(BaseModel): type: Literal["ExponentialBackoffStrategy"] - factor: Optional[Union[float, str]] = Field( + factor: float | str | None = Field( 5, description="Multiplicative constant applied on each retry.", examples=[5, 5.5, "10"], title="Factor", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SessionTokenRequestBearerAuthenticator(BaseModel): @@ -604,36 +603,36 @@ class FailureType(Enum): class HttpResponseFilter(BaseModel): type: Literal["HttpResponseFilter"] - action: Optional[Action] = Field( + action: Action | None = Field( None, description="Action to execute if a response matches the filter.", examples=["SUCCESS", "FAIL", "RETRY", "IGNORE", "RATE_LIMITED"], title="Action", ) - failure_type: Optional[FailureType] = Field( + failure_type: FailureType | None = Field( None, description="Failure type of traced exception if a response matches the filter.", examples=["system_error", "config_error", "transient_error"], title="Failure Type", ) - error_message: Optional[str] = Field( + error_message: str | None = Field( None, description="Error Message to display if the response matches the filter.", title="Error Message", ) - error_message_contains: Optional[str] = Field( + error_message_contains: str | None = Field( None, description="Match the response if its error message contains the substring.", example=["This API operation is not enabled for this site"], title="Error Message Substring", ) - http_codes: Optional[List[int]] = Field( + http_codes: list[int] | None = Field( None, description="Match the response if its HTTP code is included in this list.", examples=[[420, 429], [500]], title="HTTP Codes", ) - predicate: Optional[str] = Field( + predicate: str | None = Field( None, description="Match the response if the predicate evaluates to true.", examples=[ @@ -642,12 +641,12 @@ class HttpResponseFilter(BaseModel): ], title="Predicate", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class InlineSchemaLoader(BaseModel): type: Literal["InlineSchemaLoader"] - schema_: Optional[Dict[str, Any]] = Field( + schema_: dict[str, Any] | None = Field( None, alias="schema", description='Describes a streams\' schema. Refer to the Data Types documentation for more details on which types are valid.', @@ -657,13 +656,13 @@ class InlineSchemaLoader(BaseModel): class JsonFileSchemaLoader(BaseModel): type: Literal["JsonFileSchemaLoader"] - file_path: Optional[str] = Field( + file_path: str | None = Field( None, description="Path to the JSON file defining the schema. 
The path is relative to the connector module's root.", example=["./schemas/users.json"], title="File Path", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class JsonDecoder(BaseModel): @@ -676,7 +675,7 @@ class JsonlDecoder(BaseModel): class KeysToLower(BaseModel): type: Literal["KeysToLower"] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class IterableDecoder(BaseModel): @@ -695,30 +694,30 @@ class MinMaxDatetime(BaseModel): examples=["2021-01-01", "2021-01-01T00:00:00Z", "{{ config['start_time'] }}"], title="Datetime", ) - datetime_format: Optional[str] = Field( + datetime_format: str | None = Field( "", description='Format of the datetime value. Defaults to "%Y-%m-%dT%H:%M:%S.%f%z" if left empty. Use placeholders starting with "%" to describe the format the API is using. The following placeholders are available:\n * **%s**: Epoch unix timestamp - `1686218963`\n * **%s_as_float**: Epoch unix timestamp in seconds as float with microsecond precision - `1686218963.123456`\n * **%ms**: Epoch unix timestamp - `1686218963123`\n * **%a**: Weekday (abbreviated) - `Sun`\n * **%A**: Weekday (full) - `Sunday`\n * **%w**: Weekday (decimal) - `0` (Sunday), `6` (Saturday)\n * **%d**: Day of the month (zero-padded) - `01`, `02`, ..., `31`\n * **%b**: Month (abbreviated) - `Jan`\n * **%B**: Month (full) - `January`\n * **%m**: Month (zero-padded) - `01`, `02`, ..., `12`\n * **%y**: Year (without century, zero-padded) - `00`, `01`, ..., `99`\n * **%Y**: Year (with century) - `0001`, `0002`, ..., `9999`\n * **%H**: Hour (24-hour, zero-padded) - `00`, `01`, ..., `23`\n * **%I**: Hour (12-hour, zero-padded) - `01`, `02`, ..., `12`\n * **%p**: AM/PM indicator\n * **%M**: Minute (zero-padded) - `00`, `01`, ..., `59`\n * **%S**: Second (zero-padded) - `00`, `01`, ..., `59`\n * **%f**: Microsecond (zero-padded to 6 digits) - `000000`, `000001`, ..., `999999`\n * **%z**: UTC offset - `(empty)`, `+0000`, `-04:00`\n * **%Z**: Time zone name - `(empty)`, `UTC`, `GMT`\n * **%j**: Day of the year (zero-padded) - `001`, `002`, ..., `366`\n * **%U**: Week number of the year (Sunday as first day) - `00`, `01`, ..., `53`\n * **%W**: Week number of the year (Monday as first day) - `00`, `01`, ..., `53`\n * **%c**: Date and time representation - `Tue Aug 16 21:30:00 1988`\n * **%x**: Date representation - `08/16/1988`\n * **%X**: Time representation - `21:30:00`\n * **%%**: Literal \'%\' character\n\n Some placeholders depend on the locale of the underlying system - in most cases this locale is configured as en/US. For more information see the [Python documentation](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes).\n', examples=["%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%d", "%s"], title="Datetime Format", ) - max_datetime: Optional[str] = Field( + max_datetime: str | None = Field( None, description="Ceiling applied on the datetime value. Must be formatted with the datetime_format field.", examples=["2021-01-01T00:00:00Z", "2021-01-01"], title="Max Datetime", ) - min_datetime: Optional[str] = Field( + min_datetime: str | None = Field( None, description="Floor applied on the datetime value. 
Must be formatted with the datetime_format field.", examples=["2010-01-01T00:00:00Z", "2010-01-01"], title="Min Datetime", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class NoAuth(BaseModel): type: Literal["NoAuth"] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class NoPagination(BaseModel): @@ -729,7 +728,7 @@ class OAuthConfigSpecification(BaseModel): class Config: extra = Extra.allow - oauth_user_input_from_connector_config_specification: Optional[Dict[str, Any]] = Field( + oauth_user_input_from_connector_config_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations used as input to OAuth.\nMust be a valid non-nested JSON that refers to properties from ConnectorSpecification.connectionSpecification\nusing special annotation 'path_in_connector_config'.\nThese are input values the user is entering through the UI to authenticate to the connector, that might also shared\nas inputs for syncing data via the connector.\nExamples:\nif no connector values is shared during oauth flow, oauth_user_input_from_connector_config_specification=[]\nif connector values such as 'app_id' inside the top level are used to generate the API url for the oauth flow,\n oauth_user_input_from_connector_config_specification={\n app_id: {\n type: string\n path_in_connector_config: ['app_id']\n }\n }\nif connector values such as 'info.app_id' nested inside another object are used to generate the API url for the oauth flow,\n oauth_user_input_from_connector_config_specification={\n app_id: {\n type: string\n path_in_connector_config: ['info', 'app_id']\n }\n }", examples=[ @@ -743,7 +742,7 @@ class Config: ], title="OAuth user input", ) - complete_oauth_output_specification: Optional[Dict[str, Any]] = Field( + complete_oauth_output_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations produced by the OAuth flows as they are\nreturned by the distant OAuth APIs.\nMust be a valid JSON describing the fields to merge back to `ConnectorSpecification.connectionSpecification`.\nFor each field, a special annotation `path_in_connector_config` can be specified to determine where to merge it,\nExamples:\n complete_oauth_output_specification={\n refresh_token: {\n type: string,\n path_in_connector_config: ['credentials', 'refresh_token']\n }\n }", examples=[ @@ -756,13 +755,13 @@ class Config: ], title="OAuth output specification", ) - complete_oauth_server_input_specification: Optional[Dict[str, Any]] = Field( + complete_oauth_server_input_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. 
This is a Json Schema used to validate Json configurations persisted as Airbyte Server configurations.\nMust be a valid non-nested JSON describing additional fields configured by the Airbyte Instance or Workspace Admins to be used by the\nserver when completing an OAuth flow (typically exchanging an auth code for refresh token).\nExamples:\n complete_oauth_server_input_specification={\n client_id: {\n type: string\n },\n client_secret: {\n type: string\n }\n }", examples=[{"client_id": {"type": "string"}, "client_secret": {"type": "string"}}], title="OAuth input specification", ) - complete_oauth_server_output_specification: Optional[Dict[str, Any]] = Field( + complete_oauth_server_output_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations persisted as Airbyte Server configurations that\nalso need to be merged back into the connector configuration at runtime.\nThis is a subset configuration of `complete_oauth_server_input_specification` that filters fields out to retain only the ones that\nare necessary for the connector to function with OAuth. (some fields could be used during oauth flows but not needed afterwards, therefore\nthey would be listed in the `complete_oauth_server_input_specification` but not `complete_oauth_server_output_specification`)\nMust be a valid non-nested JSON describing additional fields configured by the Airbyte Instance or Workspace Admins to be used by the\nconnector when using OAuth flow APIs.\nThese fields are to be merged back to `ConnectorSpecification.connectionSpecification`.\nFor each field, a special annotation `path_in_connector_config` can be specified to determine where to merge it,\nExamples:\n complete_oauth_server_output_specification={\n client_id: {\n type: string,\n path_in_connector_config: ['credentials', 'client_id']\n },\n client_secret: {\n type: string,\n path_in_connector_config: ['credentials', 'client_secret']\n }\n }", examples=[ @@ -783,44 +782,44 @@ class Config: class OffsetIncrement(BaseModel): type: Literal["OffsetIncrement"] - page_size: Optional[Union[int, str]] = Field( + page_size: int | str | None = Field( None, description="The number of records to include in each pages.", examples=[100, "{{ config['page_size'] }}"], title="Limit", ) - inject_on_first_request: Optional[bool] = Field( + inject_on_first_request: bool | None = Field( False, description="Using the `offset` with value `0` during the first request", title="Inject Offset", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class PageIncrement(BaseModel): type: Literal["PageIncrement"] - page_size: Optional[Union[int, str]] = Field( + page_size: int | str | None = Field( None, description="The number of records to include in each pages.", examples=[100, "100", "{{ config['page_size'] }}"], title="Page Size", ) - start_from_page: Optional[int] = Field( + start_from_page: int | None = Field( 0, description="Index of the first page to request.", examples=[0, 1], title="Start From Page", ) - inject_on_first_request: Optional[bool] = Field( + inject_on_first_request: bool | None = Field( False, description="Using the `page number` with value defined by `start_from_page` during the first request", title="Inject Page Number", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class 
PrimaryKey(BaseModel): - __root__: Union[str, List[str], List[List[str]]] = Field( + __root__: str | list[str] | list[list[str]] = Field( ..., description="The stream field to be used to distinguish unique records. Can either be a single field, an array of fields representing a composite key, or an array of arrays representing a composite key where the fields are nested fields.", examples=["id", ["code", "type"]], @@ -830,7 +829,7 @@ class PrimaryKey(BaseModel): class RecordFilter(BaseModel): type: Literal["RecordFilter"] - condition: Optional[str] = Field( + condition: str | None = Field( "", description="The predicate to filter a record. Records will be removed if evaluated to False.", examples=[ @@ -838,7 +837,7 @@ class RecordFilter(BaseModel): "{{ record.status in ['active', 'expired'] }}", ], ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SchemaNormalization(Enum): @@ -848,7 +847,7 @@ class SchemaNormalization(Enum): class RemoveFields(BaseModel): type: Literal["RemoveFields"] - condition: Optional[str] = Field( + condition: str | None = Field( "", description="The predicate to filter a property by a property value. Property will be removed if it is empty OR expression is evaluated to True.,", examples=[ @@ -858,7 +857,7 @@ class RemoveFields(BaseModel): "{{ property == 'some_string_to_match' }}", ], ) - field_pointers: List[List[str]] = Field( + field_pointers: list[list[str]] = Field( ..., description="Array of paths defining the field to remove. Each item is an array whose field describe the path of a field to remove.", examples=[["tags"], [["content", "html"], ["content", "plain_text"]]], @@ -914,7 +913,7 @@ class LegacySessionTokenAuthenticator(BaseModel): examples=["session"], title="Login Path", ) - session_token: Optional[str] = Field( + session_token: str | None = Field( None, description="Session token to use if using a pre-defined token. Not needed if authenticating with username + password pair", example=["{{ config['session_token'] }}"], @@ -926,13 +925,13 @@ class LegacySessionTokenAuthenticator(BaseModel): examples=["id"], title="Response Token Response Key", ) - username: Optional[str] = Field( + username: str | None = Field( None, description="Username used to authenticate and obtain a session token", examples=[" {{ config['username'] }}"], title="Username", ) - password: Optional[str] = Field( + password: str | None = Field( "", description="Password used to authenticate and obtain a session token", examples=["{{ config['password'] }}", ""], @@ -944,15 +943,15 @@ class LegacySessionTokenAuthenticator(BaseModel): examples=["user/current"], title="Validate Session Path", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AsyncJobStatusMap(BaseModel): - type: Optional[Literal["AsyncJobStatusMap"]] = None - running: List[str] - completed: List[str] - failed: List[str] - timeout: List[str] + type: Literal["AsyncJobStatusMap"] | None = None + running: list[str] + completed: list[str] + failed: list[str] + timeout: list[str] class ValueType(Enum): @@ -970,19 +969,19 @@ class WaitTimeFromHeader(BaseModel): examples=["Retry-After"], title="Response Header Name", ) - regex: Optional[str] = Field( + regex: str | None = Field( None, description="Optional regex to apply on the header to extract its value. 
The regex should define a capture group defining the wait time.", examples=["([-+]?\\d+)"], title="Extraction Regex", ) - max_waiting_time_in_seconds: Optional[float] = Field( + max_waiting_time_in_seconds: float | None = Field( None, description="Given the value extracted from the header is greater than this value, stop the stream.", examples=[3600], title="Max Waiting Time in Seconds", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class WaitUntilTimeFromHeader(BaseModel): @@ -993,24 +992,24 @@ class WaitUntilTimeFromHeader(BaseModel): examples=["wait_time"], title="Response Header", ) - min_wait: Optional[Union[float, str]] = Field( + min_wait: float | str | None = Field( None, description="Minimum time to wait before retrying.", examples=[10, "60"], title="Minimum Wait Time", ) - regex: Optional[str] = Field( + regex: str | None = Field( None, description="Optional regex to apply on the header to extract its value. The regex should define a capture group defining the wait time.", examples=["([-+]?\\d+)"], title="Extraction Regex", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AddedFieldDefinition(BaseModel): type: Literal["AddedFieldDefinition"] - path: List[str] = Field( + path: list[str] = Field( ..., description="List of strings defining the path where to add the value on the record.", examples=[["segment_id"], ["metadata", "segment_id"]], @@ -1026,39 +1025,39 @@ class AddedFieldDefinition(BaseModel): ], title="Value", ) - value_type: Optional[ValueType] = Field( + value_type: ValueType | None = Field( None, description="Type of the value. If not specified, the type will be inferred from the value.", title="Value Type", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AddFields(BaseModel): type: Literal["AddFields"] - fields: List[AddedFieldDefinition] = Field( + fields: list[AddedFieldDefinition] = Field( ..., description="List of transformations (path and corresponding value) that will be added to the record.", title="Fields", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ApiKeyAuthenticator(BaseModel): type: Literal["ApiKeyAuthenticator"] - api_token: Optional[str] = Field( + api_token: str | None = Field( None, description="The API key to inject in the request. Fill it in the user inputs.", examples=["{{ config['api_key'] }}", "Token token={{ config['api_key'] }}"], title="API Key", ) - header: Optional[str] = Field( + header: str | None = Field( None, description="The name of the HTTP header that will be set to the API key. This setting is deprecated, use inject_into instead. Header and inject_into can not be defined at the same time.", examples=["Authorization", "Api-Token", "X-Auth-Token"], title="Header Name", ) - inject_into: Optional[RequestOption] = Field( + inject_into: RequestOption | None = Field( None, description="Configure how the API Key will be sent in requests to the source API. 
Either inject_into or header has to be defined.", examples=[ @@ -1067,26 +1066,26 @@ class ApiKeyAuthenticator(BaseModel): ], title="Inject API Key Into Outgoing HTTP Request", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AuthFlow(BaseModel): - auth_flow_type: Optional[AuthFlowType] = Field( + auth_flow_type: AuthFlowType | None = Field( None, description="The type of auth to use", title="Auth flow type" ) - predicate_key: Optional[List[str]] = Field( + predicate_key: list[str] | None = Field( None, description="JSON path to a field in the connectorSpecification that should exist for the advanced auth to be applicable.", examples=[["credentials", "auth_type"]], title="Predicate key", ) - predicate_value: Optional[str] = Field( + predicate_value: str | None = Field( None, description="Value of the predicate_key fields for the advanced auth to be applicable.", examples=["Oauth"], title="Predicate value", ) - oauth_config_specification: Optional[OAuthConfigSpecification] = None + oauth_config_specification: OAuthConfigSpecification | None = None class DatetimeBasedCursor(BaseModel): @@ -1103,129 +1102,128 @@ class DatetimeBasedCursor(BaseModel): examples=["%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%d", "%s", "%ms", "%s_as_float"], title="Outgoing Datetime Format", ) - start_datetime: Union[str, MinMaxDatetime] = Field( + start_datetime: str | MinMaxDatetime = Field( ..., description="The datetime that determines the earliest record that should be synced.", examples=["2020-01-1T00:00:00Z", "{{ config['start_time'] }}"], title="Start Datetime", ) - cursor_datetime_formats: Optional[List[str]] = Field( + cursor_datetime_formats: list[str] | None = Field( None, description="The possible formats for the cursor field, in order of preference. The first format that matches the cursor field value will be used to parse it. If not provided, the `datetime_format` will be used.", title="Cursor Datetime Formats", ) - cursor_granularity: Optional[str] = Field( + cursor_granularity: str | None = Field( None, description="Smallest increment the datetime_format has (ISO 8601 duration) that is used to ensure the start of a slice does not overlap with the end of the previous one, e.g. for %Y-%m-%d the granularity should be P1D, for %Y-%m-%dT%H:%M:%SZ the granularity should be PT1S. Given this field is provided, `step` needs to be provided as well.", examples=["PT1S"], title="Cursor Granularity", ) - end_datetime: Optional[Union[str, MinMaxDatetime]] = Field( + end_datetime: str | MinMaxDatetime | None = Field( None, description="The datetime that determines the last record that should be synced. If not provided, `{{ now_utc() }}` will be used.", examples=["2021-01-1T00:00:00Z", "{{ now_utc() }}", "{{ day_delta(-1) }}"], title="End Datetime", ) - end_time_option: Optional[RequestOption] = Field( + end_time_option: RequestOption | None = Field( None, description="Optionally configures how the end datetime will be sent in requests to the source API.", title="Inject End Time Into Outgoing HTTP Request", ) - is_data_feed: Optional[bool] = Field( + is_data_feed: bool | None = Field( None, description="A data feed API is an API that does not allow filtering and paginates the content from the most recent to the least recent. 
Given this, the CDK needs to know when to stop paginating and this field will generate a stop condition for pagination.", title="Whether the target API is formatted as a data feed", ) - is_client_side_incremental: Optional[bool] = Field( + is_client_side_incremental: bool | None = Field( None, description="If the target API endpoint does not take cursor values to filter records and returns all records anyway, the connector with this cursor will filter out records locally, and only emit new records from the last sync, hence incremental. This means that all records would be read from the API, but only new records will be emitted to the destination.", title="Whether the target API does not support filtering and returns all data (the cursor filters records in the client instead of the API side)", ) - is_compare_strictly: Optional[bool] = Field( + is_compare_strictly: bool | None = Field( False, description="Set to True if the target API does not accept queries where the start time equal the end time.", title="Whether to skip requests if the start time equals the end time", ) - global_substream_cursor: Optional[bool] = Field( + global_substream_cursor: bool | None = Field( False, description="This setting optimizes performance when the parent stream has thousands of partitions by storing the cursor as a single value rather than per partition. Notably, the substream state is updated only at the end of the sync, which helps prevent data loss in case of a sync failure. See more info in the [docs](https://docs.airbyte.com/connector-development/config-based/understanding-the-yaml-file/incremental-syncs).", title="Whether to store cursor as one value instead of per partition", ) - lookback_window: Optional[str] = Field( + lookback_window: str | None = Field( None, description="Time interval before the start_datetime to read data for, e.g. P1M for looking back one month.", examples=["P1D", "P{{ config['lookback_days'] }}D"], title="Lookback Window", ) - partition_field_end: Optional[str] = Field( + partition_field_end: str | None = Field( None, description="Name of the partition start time field.", examples=["ending_time"], title="Partition Field End", ) - partition_field_start: Optional[str] = Field( + partition_field_start: str | None = Field( None, description="Name of the partition end time field.", examples=["starting_time"], title="Partition Field Start", ) - start_time_option: Optional[RequestOption] = Field( + start_time_option: RequestOption | None = Field( None, description="Optionally configures how the start datetime will be sent in requests to the source API.", title="Inject Start Time Into Outgoing HTTP Request", ) - step: Optional[str] = Field( + step: str | None = Field( None, description="The size of the time window (ISO8601 duration). 
Given this field is provided, `cursor_granularity` needs to be provided as well.", examples=["P1W", "{{ config['step_increment'] }}"], title="Step", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DefaultErrorHandler(BaseModel): type: Literal["DefaultErrorHandler"] - backoff_strategies: Optional[ - List[ - Union[ - ConstantBackoffStrategy, - CustomBackoffStrategy, - ExponentialBackoffStrategy, - WaitTimeFromHeader, - WaitUntilTimeFromHeader, - ] + backoff_strategies: ( + list[ + ConstantBackoffStrategy + | CustomBackoffStrategy + | ExponentialBackoffStrategy + | WaitTimeFromHeader + | WaitUntilTimeFromHeader ] - ] = Field( + | None + ) = Field( None, description="List of backoff strategies to use to determine how long to wait before retrying a retryable request.", title="Backoff Strategies", ) - max_retries: Optional[int] = Field( + max_retries: int | None = Field( 5, description="The maximum number of time to retry a retryable request before giving up and failing.", examples=[5, 0, 10], title="Max Retry Count", ) - response_filters: Optional[List[HttpResponseFilter]] = Field( + response_filters: list[HttpResponseFilter] | None = Field( None, description="List of response filters to iterate on when deciding how to handle an error. When using an array of multiple filters, the filters will be applied sequentially and the response will be selected if it matches any of the filter's predicate.", title="Response Filters", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DefaultPaginator(BaseModel): type: Literal["DefaultPaginator"] - pagination_strategy: Union[ - CursorPagination, CustomPaginationStrategy, OffsetIncrement, PageIncrement - ] = Field( + pagination_strategy: ( + CursorPagination | CustomPaginationStrategy | OffsetIncrement | PageIncrement + ) = Field( ..., description="Strategy defining how records are paginated.", title="Pagination Strategy", ) - page_size_option: Optional[RequestOption] = None - page_token_option: Optional[Union[RequestOption, RequestPath]] = None - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + page_size_option: RequestOption | None = None + page_token_option: RequestOption | RequestPath | None = None + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SessionTokenRequestApiKeyAuthenticator(BaseModel): @@ -1249,46 +1247,46 @@ class ListPartitionRouter(BaseModel): examples=["section", "{{ config['section_key'] }}"], title="Current Partition Value Identifier", ) - values: Union[str, List[str]] = Field( + values: str | list[str] = Field( ..., description="The list of attributes being iterated over and used as input for the requests made to the source API.", examples=[["section_a", "section_b", "section_c"], "{{ config['sections'] }}"], title="Partition Values", ) - request_option: Optional[RequestOption] = Field( + request_option: RequestOption | None = Field( None, description="A request option describing where the list value should be injected into and under what field name if applicable.", title="Inject Partition Value Into Outgoing HTTP Request", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class RecordSelector(BaseModel): type: Literal["RecordSelector"] - extractor: Union[CustomRecordExtractor, 
DpathExtractor] - record_filter: Optional[Union[CustomRecordFilter, RecordFilter]] = Field( + extractor: CustomRecordExtractor | DpathExtractor + record_filter: CustomRecordFilter | RecordFilter | None = Field( None, description="Responsible for filtering records to be emitted by the Source.", title="Record Filter", ) - schema_normalization: Optional[SchemaNormalization] = SchemaNormalization.None_ - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + schema_normalization: SchemaNormalization | None = SchemaNormalization.None_ + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class Spec(BaseModel): type: Literal["Spec"] - connection_specification: Dict[str, Any] = Field( + connection_specification: dict[str, Any] = Field( ..., description="A connection specification describing how a the connector can be configured.", title="Connection Specification", ) - documentation_url: Optional[str] = Field( + documentation_url: str | None = Field( None, description="URL of the connector's documentation page.", examples=["https://docs.airbyte.com/integrations/sources/dremio"], title="Documentation URL", ) - advanced_auth: Optional[AuthFlow] = Field( + advanced_auth: AuthFlow | None = Field( None, description="Advanced specification for configuring the authentication flow.", title="Advanced Auth", @@ -1297,12 +1295,12 @@ class Spec(BaseModel): class CompositeErrorHandler(BaseModel): type: Literal["CompositeErrorHandler"] - error_handlers: List[Union[CompositeErrorHandler, DefaultErrorHandler]] = Field( + error_handlers: list[CompositeErrorHandler | DefaultErrorHandler] = Field( ..., description="List of error handlers to iterate on to determine how to handle a failed response.", title="Error Handlers", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DeclarativeSource(BaseModel): @@ -1311,20 +1309,20 @@ class Config: type: Literal["DeclarativeSource"] check: CheckStream - streams: List[DeclarativeStream] + streams: list[DeclarativeStream] version: str = Field( ..., description="The version of the Airbyte CDK used to build and test the source.", ) - schemas: Optional[Schemas] = None - definitions: Optional[Dict[str, Any]] = None - spec: Optional[Spec] = None - concurrency_level: Optional[ConcurrencyLevel] = None - metadata: Optional[Dict[str, Any]] = Field( + schemas: Schemas | None = None + definitions: dict[str, Any] | None = None + spec: Spec | None = None + concurrency_level: ConcurrencyLevel | None = None + metadata: dict[str, Any] | None = Field( None, description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.", ) - description: Optional[str] = Field( + description: str | None = Field( None, description="A description of the connector. 
It will be presented on the Source documentation page.", ) @@ -1335,25 +1333,23 @@ class Config: extra = Extra.allow type: Literal["SelectiveAuthenticator"] - authenticator_selection_path: List[str] = Field( + authenticator_selection_path: list[str] = Field( ..., description="Path of the field in config with selected authenticator name", examples=[["auth"], ["auth", "type"]], title="Authenticator Selection Path", ) - authenticators: Dict[ + authenticators: dict[ str, - Union[ - ApiKeyAuthenticator, - BasicHttpAuthenticator, - BearerAuthenticator, - CustomAuthenticator, - OAuthAuthenticator, - JwtAuthenticator, - NoAuth, - SessionTokenAuthenticator, - LegacySessionTokenAuthenticator, - ], + ApiKeyAuthenticator + | BasicHttpAuthenticator + | BearerAuthenticator + | CustomAuthenticator + | OAuthAuthenticator + | JwtAuthenticator + | NoAuth + | SessionTokenAuthenticator + | LegacySessionTokenAuthenticator, ] = Field( ..., description="Authenticators to select from.", @@ -1368,7 +1364,7 @@ class Config: ], title="Authenticators", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DeclarativeStream(BaseModel): @@ -1376,42 +1372,40 @@ class Config: extra = Extra.allow type: Literal["DeclarativeStream"] - retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field( + retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field( ..., description="Component used to coordinate how records are extracted across stream slices and request pages.", title="Retriever", ) - incremental_sync: Optional[Union[CustomIncrementalSync, DatetimeBasedCursor]] = Field( + incremental_sync: CustomIncrementalSync | DatetimeBasedCursor | None = Field( None, description="Component used to fetch data incrementally based on a time field in the data.", title="Incremental Sync", ) - name: Optional[str] = Field("", description="The stream name.", example=["Users"], title="Name") - primary_key: Optional[PrimaryKey] = Field( + name: str | None = Field("", description="The stream name.", example=["Users"], title="Name") + primary_key: PrimaryKey | None = Field( "", description="The primary key of the stream.", title="Primary Key" ) - schema_loader: Optional[Union[InlineSchemaLoader, JsonFileSchemaLoader, CustomSchemaLoader]] = ( + schema_loader: InlineSchemaLoader | JsonFileSchemaLoader | CustomSchemaLoader | None = Field( + None, + description="Component used to retrieve the schema for the current stream.", + title="Schema Loader", + ) + transformations: list[AddFields | CustomTransformation | RemoveFields | KeysToLower] | None = ( Field( None, - description="Component used to retrieve the schema for the current stream.", - title="Schema Loader", + description="A list of transformations to be applied to each output record.", + title="Transformations", ) ) - transformations: Optional[ - List[Union[AddFields, CustomTransformation, RemoveFields, KeysToLower]] - ] = Field( - None, - description="A list of transformations to be applied to each output record.", - title="Transformations", - ) - state_migrations: Optional[ - List[Union[LegacyToPerPartitionStateMigration, CustomStateMigration]] - ] = Field( - [], - description="Array of state migrations to be applied on the input state", - title="State Migrations", + state_migrations: list[LegacyToPerPartitionStateMigration | CustomStateMigration] | None = ( + Field( + [], + description="Array of state migrations to be applied on the input state", + title="State 
Migrations", + ) ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SessionTokenAuthenticator(BaseModel): @@ -1433,29 +1427,29 @@ class SessionTokenAuthenticator(BaseModel): ], title="Login Requester", ) - session_token_path: List[str] = Field( + session_token_path: list[str] = Field( ..., description="The path in the response body returned from the login requester to the session token.", examples=[["access_token"], ["result", "token"]], title="Session Token Path", ) - expiration_duration: Optional[str] = Field( + expiration_duration: str | None = Field( None, description="The duration in ISO 8601 duration notation after which the session token expires, starting from the time it was obtained. Omitting it will result in the session token being refreshed for every request.", examples=["PT1H", "P1D"], title="Expiration Duration", ) - request_authentication: Union[ - SessionTokenRequestApiKeyAuthenticator, SessionTokenRequestBearerAuthenticator - ] = Field( + request_authentication: ( + SessionTokenRequestApiKeyAuthenticator | SessionTokenRequestBearerAuthenticator + ) = Field( ..., description="Authentication method to use for requests sent to the API, specifying how to inject the session token.", title="Data Request Authentication", ) - decoder: Optional[Union[JsonDecoder, XmlDecoder]] = Field( + decoder: JsonDecoder | XmlDecoder | None = Field( None, description="Component used to decode the response.", title="Decoder" ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class HttpRequester(BaseModel): @@ -1479,38 +1473,35 @@ class HttpRequester(BaseModel): ], title="URL Path", ) - authenticator: Optional[ - Union[ - ApiKeyAuthenticator, - BasicHttpAuthenticator, - BearerAuthenticator, - CustomAuthenticator, - OAuthAuthenticator, - JwtAuthenticator, - NoAuth, - SessionTokenAuthenticator, - LegacySessionTokenAuthenticator, - SelectiveAuthenticator, - ] - ] = Field( + authenticator: ( + ApiKeyAuthenticator + | BasicHttpAuthenticator + | BearerAuthenticator + | CustomAuthenticator + | OAuthAuthenticator + | JwtAuthenticator + | NoAuth + | SessionTokenAuthenticator + | LegacySessionTokenAuthenticator + | SelectiveAuthenticator + | None + ) = Field( None, description="Authentication method to use for requests sent to the API.", title="Authenticator", ) - error_handler: Optional[ - Union[DefaultErrorHandler, CustomErrorHandler, CompositeErrorHandler] - ] = Field( + error_handler: DefaultErrorHandler | CustomErrorHandler | CompositeErrorHandler | None = Field( None, description="Error handler component that defines how to handle errors.", title="Error Handler", ) - http_method: Optional[HttpMethod] = Field( + http_method: HttpMethod | None = Field( HttpMethod.GET, description="The HTTP method used to fetch data from the source (can be GET or POST).", examples=["GET", "POST"], title="HTTP Method", ) - request_body_data: Optional[Union[str, Dict[str, str]]] = Field( + request_body_data: str | dict[str, str] | None = Field( None, description="Specifies how to populate the body of the request with a non-JSON payload. 
Plain text will be sent as is, whereas objects will be converted to a urlencoded form.", examples=[ @@ -1518,7 +1509,7 @@ class HttpRequester(BaseModel): ], title="Request Body Payload (Non-JSON)", ) - request_body_json: Optional[Union[str, Dict[str, Any]]] = Field( + request_body_json: str | dict[str, Any] | None = Field( None, description="Specifies how to populate the body of the request with a JSON payload. Can contain nested objects.", examples=[ @@ -1528,13 +1519,13 @@ class HttpRequester(BaseModel): ], title="Request Body JSON Payload", ) - request_headers: Optional[Union[str, Dict[str, str]]] = Field( + request_headers: str | dict[str, str] | None = Field( None, description="Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.", examples=[{"Output-Format": "JSON"}, {"Version": "{{ config['version'] }}"}], title="Request Headers", ) - request_parameters: Optional[Union[str, Dict[str, str]]] = Field( + request_parameters: str | dict[str, str] | None = Field( None, description="Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.", examples=[ @@ -1547,12 +1538,12 @@ class HttpRequester(BaseModel): ], title="Query Parameters", ) - use_cache: Optional[bool] = Field( + use_cache: bool | None = Field( False, description="Enables stream requests caching. This field is automatically set by the CDK.", title="Use Cache", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ParentStreamConfig(BaseModel): @@ -1572,22 +1563,22 @@ class ParentStreamConfig(BaseModel): examples=["parent_id", "{{ config['parent_partition_field'] }}"], title="Current Parent Key Value Identifier", ) - request_option: Optional[RequestOption] = Field( + request_option: RequestOption | None = Field( None, description="A request option describing where the parent key value should be injected into and under what field name if applicable.", title="Request Option", ) - incremental_dependency: Optional[bool] = Field( + incremental_dependency: bool | None = Field( False, description="Indicates whether the parent stream should be read incrementally based on updates in the child stream.", title="Incremental Dependency", ) - extra_fields: Optional[List[List[str]]] = Field( + extra_fields: list[list[str]] | None = Field( None, description="Array of field paths to include as additional fields in the stream slice. Each path is an array of strings representing keys to access fields in the respective parent record. Accessible via `stream_slice.extra_fields`. 
Missing fields are set to `None`.", title="Extra Fields", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SimpleRetriever(BaseModel): @@ -1596,36 +1587,35 @@ class SimpleRetriever(BaseModel): ..., description="Component that describes how to extract records from a HTTP response.", ) - requester: Union[CustomRequester, HttpRequester] = Field( + requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API.", ) - paginator: Optional[Union[DefaultPaginator, NoPagination]] = Field( + paginator: DefaultPaginator | NoPagination | None = Field( None, description="Paginator component that describes how to navigate through the API's pages.", ) - ignore_stream_slicer_parameters_on_paginated_requests: Optional[bool] = Field( + ignore_stream_slicer_parameters_on_paginated_requests: bool | None = Field( False, description="If true, the partition router and incremental request options will be ignored when paginating requests. Request options set directly on the requester will not be ignored.", ) - partition_router: Optional[ - Union[ - CustomPartitionRouter, - ListPartitionRouter, - SubstreamPartitionRouter, - List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]], - ] - ] = Field( + partition_router: ( + CustomPartitionRouter + | ListPartitionRouter + | SubstreamPartitionRouter + | list[CustomPartitionRouter | ListPartitionRouter | SubstreamPartitionRouter] + | None + ) = Field( [], description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.", title="Partition Router", ) - decoder: Optional[Union[JsonDecoder, JsonlDecoder, IterableDecoder, XmlDecoder]] = Field( + decoder: JsonDecoder | JsonlDecoder | IterableDecoder | XmlDecoder | None = Field( None, description="Component decoding the response so records can be extracted.", title="Decoder", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AsyncRetriever(BaseModel): @@ -1637,65 +1627,64 @@ class AsyncRetriever(BaseModel): status_mapping: AsyncJobStatusMap = Field( ..., description="Async Job Status to Airbyte CDK Async Job Status mapping." ) - status_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field( + status_extractor: CustomRecordExtractor | DpathExtractor = Field( ..., description="Responsible for fetching the actual status of the async job." 
) - urls_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field( + urls_extractor: CustomRecordExtractor | DpathExtractor = Field( ..., description="Responsible for fetching the final result `urls` provided by the completed / finished / ready async job.", ) - creation_requester: Union[CustomRequester, HttpRequester] = Field( + creation_requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API to create the async server-side job.", ) - polling_requester: Union[CustomRequester, HttpRequester] = Field( + polling_requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API to fetch the status of the running async job.", ) - download_requester: Union[CustomRequester, HttpRequester] = Field( + download_requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API to download the data provided by the completed async job.", ) - download_paginator: Optional[Union[DefaultPaginator, NoPagination]] = Field( + download_paginator: DefaultPaginator | NoPagination | None = Field( None, description="Paginator component that describes how to navigate through the API's pages during download.", ) - abort_requester: Optional[Union[CustomRequester, HttpRequester]] = Field( + abort_requester: CustomRequester | HttpRequester | None = Field( None, description="Requester component that describes how to prepare HTTP requests to send to the source API to abort a job once it is timed out from the source's perspective.", ) - delete_requester: Optional[Union[CustomRequester, HttpRequester]] = Field( + delete_requester: CustomRequester | HttpRequester | None = Field( None, description="Requester component that describes how to prepare HTTP requests to send to the source API to delete a job once the records are extracted.", ) - partition_router: Optional[ - Union[ - CustomPartitionRouter, - ListPartitionRouter, - SubstreamPartitionRouter, - List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]], - ] - ] = Field( + partition_router: ( + CustomPartitionRouter + | ListPartitionRouter + | SubstreamPartitionRouter + | list[CustomPartitionRouter | ListPartitionRouter | SubstreamPartitionRouter] + | None + ) = Field( [], description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.", title="Partition Router", ) - decoder: Optional[Union[JsonDecoder, JsonlDecoder, IterableDecoder, XmlDecoder]] = Field( + decoder: JsonDecoder | JsonlDecoder | IterableDecoder | XmlDecoder | None = Field( None, description="Component decoding the response so records can be extracted.", title="Decoder", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SubstreamPartitionRouter(BaseModel): type: Literal["SubstreamPartitionRouter"] - parent_stream_configs: List[ParentStreamConfig] = Field( + parent_stream_configs: list[ParentStreamConfig] = Field( ..., description="Specifies which parent streams are being iterated over and how parent records should be used to partition the child stream data set.", title="Parent Stream Configs", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, 
alias="$parameters") CompositeErrorHandler.update_forward_refs() diff --git a/airbyte_cdk/sources/declarative/parsers/custom_exceptions.py b/airbyte_cdk/sources/declarative/parsers/custom_exceptions.py index d6fdee69..c0db9ca0 100644 --- a/airbyte_cdk/sources/declarative/parsers/custom_exceptions.py +++ b/airbyte_cdk/sources/declarative/parsers/custom_exceptions.py @@ -1,21 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations class CircularReferenceException(Exception): - """ - Raised when a circular reference is detected in a manifest. - """ + """Raised when a circular reference is detected in a manifest.""" def __init__(self, reference: str) -> None: super().__init__(f"Circular reference found: {reference}") class UndefinedReferenceException(Exception): - """ - Raised when refering to an undefined reference. - """ + """Raised when refering to an undefined reference.""" def __init__(self, path: str, reference: str) -> None: super().__init__(f"Undefined reference {reference} from {path}") diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py index 8cacda3d..978f037c 100644 --- a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py +++ b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import copy -import typing -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any + PARAMETERS_STR = "$parameters" @@ -82,8 +84,7 @@ def propagate_types_and_parameters( declarative_component: Mapping[str, Any], parent_parameters: Mapping[str, Any], ) -> Mapping[str, Any]: - """ - Recursively transforms the specified declarative component and subcomponents to propagate parameters and insert the + """Recursively transforms the specified declarative component and subcomponents to propagate parameters and insert the default component type if it was not already present. The resulting transformed components are a deep copy of the input components, not an in-place transformation. @@ -136,7 +137,7 @@ def propagate_types_and_parameters( ) if excluded_parameter: current_parameters[field_name] = excluded_parameter - elif isinstance(field_value, typing.List): + elif isinstance(field_value, list): # We exclude propagating a parameter that matches the current field name because that would result in an infinite cycle excluded_parameter = current_parameters.pop(field_name, None) for i, element in enumerate(field_value): diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py index 045ea9a2..7e1ae263 100644 --- a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py +++ b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py @@ -1,21 +1,23 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import re -from typing import Any, Mapping, Set, Tuple, Union +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.declarative.parsers.custom_exceptions import ( CircularReferenceException, UndefinedReferenceException, ) + REF_TAG = "$ref" class ManifestReferenceResolver: - """ - An incoming manifest can contain references to values previously defined. 
+ """An incoming manifest can contain references to values previously defined. This parser will dereference these values to produce a complete ConnectionDefinition. References can be defined using a #/ string. @@ -100,13 +102,12 @@ class ManifestReferenceResolver: """ def preprocess_manifest(self, manifest: Mapping[str, Any]) -> Mapping[str, Any]: - """ - :param manifest: incoming manifest that could have references to previously defined components + """:param manifest: incoming manifest that could have references to previously defined components :return: """ return self._evaluate_node(manifest, manifest, set()) # type: ignore[no-any-return] - def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[Any]) -> Any: + def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: set[Any]) -> Any: if isinstance(node, dict): evaluated_dict = { k: self._evaluate_node(v, manifest, visited) @@ -118,22 +119,19 @@ def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[An evaluated_ref = self._evaluate_node(node[REF_TAG], manifest, visited) if not isinstance(evaluated_ref, dict): return evaluated_ref - else: - # The values defined on the component take precedence over the reference values - return evaluated_ref | evaluated_dict - else: - return evaluated_dict - elif isinstance(node, list): + # The values defined on the component take precedence over the reference values + return evaluated_ref | evaluated_dict + return evaluated_dict + if isinstance(node, list): return [self._evaluate_node(v, manifest, visited) for v in node] - elif self._is_ref(node): + if self._is_ref(node): if node in visited: raise CircularReferenceException(node) visited.add(node) ret = self._evaluate_node(self._lookup_ref_value(node, manifest), manifest, visited) visited.remove(node) return ret - else: - return node + return node def _lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any: ref_match = re.match(r"#/(.*)", ref) @@ -155,8 +153,7 @@ def _is_ref_key(key: str) -> bool: @staticmethod def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any: - """ - Read the value at the referenced location of the manifest. + """Read the value at the referenced location of the manifest. References are ambiguous because one could define a key containing `/` In this example, we want to refer to the `limit` key in the `dict` object: @@ -185,9 +182,8 @@ def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any: return manifest_node -def _parse_path(ref: str) -> Tuple[Union[str, int], str]: - """ - Return the next path component, together with the rest of the path. +def _parse_path(ref: str) -> tuple[str | int, str]: + """Return the next path component, together with the rest of the path. A path component may be a string key, or an int index. 
diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 2812ba81..243b6bdb 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -8,23 +8,18 @@ import importlib import inspect import re +from collections.abc import Callable, Mapping, MutableMapping from functools import partial from typing import ( Any, - Callable, - Dict, - List, - Mapping, - MutableMapping, - Optional, - Tuple, - Type, - Union, get_args, get_origin, get_type_hints, ) +from isodate import parse_duration +from pydantic.v1 import BaseModel + from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator @@ -360,8 +355,7 @@ from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction from airbyte_cdk.sources.types import Config from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer -from isodate import parse_duration -from pydantic.v1 import BaseModel + ComponentDefinition = Mapping[str, Any] @@ -371,12 +365,12 @@ class ModelToComponentFactory: def __init__( self, - limit_pages_fetched_per_slice: Optional[int] = None, - limit_slices_fetched: Optional[int] = None, + limit_pages_fetched_per_slice: int | None = None, + limit_slices_fetched: int | None = None, emit_connector_builder_messages: bool = False, disable_retries: bool = False, disable_cache: bool = False, - message_repository: Optional[MessageRepository] = None, + message_repository: MessageRepository | None = None, ): self._init_mappings() self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice @@ -389,7 +383,7 @@ def __init__( ) def _init_mappings(self) -> None: - self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = { + self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[type[BaseModel], Callable[..., Any]] = { AddedFieldDefinitionModel: self.create_added_field_definition, AddFieldsModel: self.create_add_fields, ApiKeyAuthenticatorModel: self.create_api_key_authenticator, @@ -459,13 +453,12 @@ def _init_mappings(self) -> None: def create_component( self, - model_type: Type[BaseModel], + model_type: type[BaseModel], component_definition: ComponentDefinition, config: Config, **kwargs: Any, ) -> Any: - """ - Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and + """Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and subcomponents which will be used at runtime. This is done by first parsing the mapping into a Pydantic model and then creating creating declarative components from that model. 
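As a rough illustration of that two-step flow (parse the component definition mapping into a Pydantic model, then hand the typed model to a registered constructor), here is a small self-contained sketch. The model name, constructor, and registry below are invented for the example; they are not the factory's real mappings.

    # Illustrative two-step flow: mapping -> Pydantic model -> runtime component.
    # ExampleBackoff, build_example_backoff and CONSTRUCTORS are made-up stand-ins.
    from collections.abc import Callable, Mapping
    from typing import Any

    from pydantic.v1 import BaseModel

    class ExampleBackoff(BaseModel):
        type: str
        backoff_time_in_seconds: float

    def build_example_backoff(model: ExampleBackoff, config: Mapping[str, Any]) -> dict[str, Any]:
        # Stand-in for a runtime component built from the validated model.
        return {"backoff_seconds": model.backoff_time_in_seconds}

    CONSTRUCTORS: Mapping[type[BaseModel], Callable[..., Any]] = {
        ExampleBackoff: build_example_backoff,
    }

    def create_component_sketch(
        model_type: type[BaseModel], definition: Mapping[str, Any], config: Mapping[str, Any]
    ) -> Any:
        if definition.get("type") != model_type.__name__:
            raise ValueError(f"Expected {model_type.__name__}, got {definition.get('type')}")
        parsed = model_type.parse_obj(definition)  # step 1: validate and type the mapping
        return CONSTRUCTORS[model_type](parsed, config)  # step 2: build the runtime object

    print(create_component_sketch(ExampleBackoff, {"type": "ExampleBackoff", "backoff_time_in_seconds": 2.5}, {}))
    # -> {'backoff_seconds': 2.5}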
@@ -474,7 +467,6 @@ def create_component( :param config: The connector config that is provided by the customer :return: The declarative component to be used at runtime """ - component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: raise ValueError( @@ -535,7 +527,7 @@ def create_keys_to_lower_transformation( return KeysToLowerTransformation() @staticmethod - def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[Type[Any]]: + def _json_schema_type_name_to_type(value_type: ValueType | None) -> type[Any] | None: if not value_type: return None names_to_types = { @@ -550,7 +542,7 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[ def create_api_key_authenticator( model: ApiKeyAuthenticatorModel, config: Config, - token_provider: Optional[TokenProvider] = None, + token_provider: TokenProvider | None = None, **kwargs: Any, ) -> ApiKeyAuthenticator: if model.inject_into is None and model.header is None: @@ -628,7 +620,7 @@ def create_legacy_to_per_partition_state_migration( def create_session_token_authenticator( self, model: SessionTokenAuthenticatorModel, config: Config, name: str, **kwargs: Any - ) -> Union[ApiKeyAuthenticator, BearerAuthenticator]: + ) -> ApiKeyAuthenticator | BearerAuthenticator: decoder = ( self._create_component_from_model(model=model.decoder, config=config) if model.decoder @@ -656,16 +648,15 @@ def create_session_token_authenticator( config, token_provider=token_provider, # type: ignore # $parameters defaults to None ) - else: - return ModelToComponentFactory.create_api_key_authenticator( - ApiKeyAuthenticatorModel( - type="ApiKeyAuthenticator", - api_token="", - inject_into=model.request_authentication.inject_into, - ), # type: ignore # $parameters and headers default to None - config=config, - token_provider=token_provider, - ) + return ModelToComponentFactory.create_api_key_authenticator( + ApiKeyAuthenticatorModel( + type="ApiKeyAuthenticator", + api_token="", + inject_into=model.request_authentication.inject_into, + ), # type: ignore # $parameters and headers default to None + config=config, + token_provider=token_provider, + ) @staticmethod def create_basic_http_authenticator( @@ -682,7 +673,7 @@ def create_basic_http_authenticator( def create_bearer_authenticator( model: BearerAuthenticatorModel, config: Config, - token_provider: Optional[TokenProvider] = None, + token_provider: TokenProvider | None = None, **kwargs: Any, ) -> BearerAuthenticator: if token_provider is not None and model.api_token != "": @@ -732,14 +723,14 @@ def create_concurrency_level( def create_concurrent_cursor_from_datetime_based_cursor( self, state_manager: ConnectorStateManager, - model_type: Type[BaseModel], + model_type: type[BaseModel], component_definition: ComponentDefinition, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, config: Config, stream_state: MutableMapping[str, Any], **kwargs: Any, - ) -> Tuple[ConcurrentCursor, DateTimeStreamStateConverter]: + ) -> tuple[ConcurrentCursor, DateTimeStreamStateConverter]: component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: raise ValueError( @@ -804,7 +795,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( # type: ignore # Having issues w/ inspection for GapType and CursorValueType as shown in existing tests. 
Confirmed functionality is working in practice ) - start_date_runtime_value: Union[InterpolatedString, str, MinMaxDatetime] + start_date_runtime_value: InterpolatedString | str | MinMaxDatetime if isinstance(datetime_based_cursor_model.start_datetime, MinMaxDatetimeModel): start_date_runtime_value = self.create_min_max_datetime( model=datetime_based_cursor_model.start_datetime, config=config @@ -812,7 +803,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( else: start_date_runtime_value = datetime_based_cursor_model.start_datetime - end_date_runtime_value: Optional[Union[InterpolatedString, str, MinMaxDatetime]] + end_date_runtime_value: InterpolatedString | str | MinMaxDatetime | None if isinstance(datetime_based_cursor_model.end_datetime, MinMaxDatetimeModel): end_date_runtime_value = self.create_min_max_datetime( model=datetime_based_cursor_model.end_datetime, config=config @@ -925,14 +916,12 @@ def create_cursor_pagination( ) def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any: - """ - Generically creates a custom component based on the model type and a class_name reference to the custom Python class being + """Generically creates a custom component based on the model type and a class_name reference to the custom Python class being instantiated. Only the model's additional properties that match the custom class definition are passed to the constructor :param model: The Pydantic model of the custom component being created :param config: The custom defined connector config :return: The declarative component built from the Pydantic model to be used at runtime """ - custom_component_class = self._get_class_from_fully_qualified_class_name(model.class_name) component_fields = get_type_hints(custom_component_class) model_args = model.dict() @@ -996,7 +985,7 @@ def _get_class_from_fully_qualified_class_name(full_qualified_class_name: str) - raise ValueError(f"Could not load class {full_qualified_class_name}.") @staticmethod - def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]: + def _derive_component_type_from_type_hints(field_type: Any) -> str | None: interface = field_type while True: origin = get_origin(interface) @@ -1013,18 +1002,17 @@ def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]: return None @staticmethod - def is_builtin_type(cls: Optional[Type[Any]]) -> bool: + def is_builtin_type(cls: type[Any] | None) -> bool: if not cls: return False return cls.__module__ == "builtins" @staticmethod - def _extract_missing_parameters(error: TypeError) -> List[str]: + def _extract_missing_parameters(error: TypeError) -> list[str]: parameter_search = re.search(r"keyword-only.*:\s(.*)", str(error)) if parameter_search: return re.findall(r"\'(.+?)\'", parameter_search.group(1)) - else: - return [] + return [] def _create_nested_component( self, model: Any, model_field: str, model_value: Any, config: Config @@ -1061,10 +1049,8 @@ def _create_nested_component( raise ValueError( f"Error creating component '{type_name}' with parent custom component {model.class_name}: Please provide " + ", ".join( - ( - f"{type_name}.$parameters.{parameter}" - for parameter in missing_parameters - ) + f"{type_name}.$parameters.{parameter}" + for parameter in missing_parameters ) ) raise TypeError( @@ -1082,12 +1068,12 @@ def _is_component(model_value: Any) -> bool: def create_datetime_based_cursor( self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any ) -> DatetimeBasedCursor: - start_datetime: Union[str, 
MinMaxDatetime] = ( + start_datetime: str | MinMaxDatetime = ( model.start_datetime if isinstance(model.start_datetime, str) else self.create_min_max_datetime(model.start_datetime, config) ) - end_datetime: Union[str, MinMaxDatetime, None] = None + end_datetime: str | MinMaxDatetime | None = None if model.is_data_feed and model.end_datetime: raise ValueError("Data feed does not support end_datetime") if model.is_data_feed and model.is_client_side_incremental: @@ -1122,9 +1108,7 @@ def create_datetime_based_cursor( return DatetimeBasedCursor( cursor_field=model.cursor_field, - cursor_datetime_formats=model.cursor_datetime_formats - if model.cursor_datetime_formats - else [], + cursor_datetime_formats=model.cursor_datetime_formats or [], cursor_granularity=model.cursor_granularity, datetime_format=model.datetime_format, end_datetime=end_datetime, @@ -1267,7 +1251,7 @@ def create_declarative_stream( def _merge_stream_slicers( self, model: DeclarativeStreamModel, config: Config - ) -> Optional[StreamSlicer]: + ) -> StreamSlicer | None: stream_slicer = None if ( hasattr(model.retriever, "partition_router") @@ -1301,26 +1285,25 @@ def _merge_stream_slicers( return GlobalSubstreamCursor( stream_cursor=cursor_component, partition_router=stream_slicer ) - else: - cursor_component = self._create_component_from_model( - model=incremental_sync_model, config=config - ) - return PerPartitionWithGlobalCursor( - cursor_factory=CursorFactory( - lambda: self._create_component_from_model( - model=incremental_sync_model, config=config - ), + cursor_component = self._create_component_from_model( + model=incremental_sync_model, config=config + ) + return PerPartitionWithGlobalCursor( + cursor_factory=CursorFactory( + lambda: self._create_component_from_model( + model=incremental_sync_model, config=config ), - partition_router=stream_slicer, - stream_cursor=cursor_component, - ) - elif model.incremental_sync: + ), + partition_router=stream_slicer, + stream_cursor=cursor_component, + ) + if model.incremental_sync: return ( self._create_component_from_model(model=model.incremental_sync, config=config) if model.incremental_sync else None ) - elif stream_slicer: + if stream_slicer: # For the Full-Refresh sub-streams, we use the nested `ChildPartitionResumableFullRefreshCursor` return PerPartitionCursor( cursor_factory=CursorFactory( @@ -1328,15 +1311,14 @@ def _merge_stream_slicers( ), partition_router=stream_slicer, ) - elif ( + if ( hasattr(model.retriever, "paginator") and model.retriever.paginator and not stream_slicer ): # For the regular Full-Refresh streams, we use the high level `ResumableFullRefreshCursor` return ResumableFullRefreshCursor(parameters={}) - else: - return None + return None def create_default_error_handler( self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any @@ -1372,9 +1354,9 @@ def create_default_paginator( config: Config, *, url_base: str, - decoder: Optional[Decoder] = None, - cursor_used_for_stop_condition: Optional[DeclarativeCursor] = None, - ) -> Union[DefaultPaginator, PaginatorTestReadDecorator]: + decoder: Decoder | None = None, + cursor_used_for_stop_condition: DeclarativeCursor | None = None, + ) -> DefaultPaginator | PaginatorTestReadDecorator: if decoder: if not isinstance(decoder, (JsonDecoder, XmlDecoder)): raise ValueError( @@ -1417,14 +1399,14 @@ def create_dpath_extractor( self, model: DpathExtractorModel, config: Config, - decoder: Optional[Decoder] = None, + decoder: Decoder | None = None, **kwargs: Any, ) -> DpathExtractor: if decoder: decoder_to_use = 
decoder else: decoder_to_use = JsonDecoder(parameters={}) - model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path] + model_field_path: list[InterpolatedString | str] = [x for x in model.field_path] return DpathExtractor( decoder=decoder_to_use, field_path=model_field_path, @@ -1761,9 +1743,9 @@ def create_record_selector( model: RecordSelectorModel, config: Config, *, - transformations: List[RecordTransformation], - decoder: Optional[Decoder] = None, - client_side_incremental_sync: Optional[Dict[str, Any]] = None, + transformations: list[RecordTransformation], + decoder: Decoder | None = None, + client_side_incremental_sync: dict[str, Any] | None = None, **kwargs: Any, ) -> RecordSelector: assert model.schema_normalization is not None # for mypy @@ -1843,12 +1825,12 @@ def create_simple_retriever( config: Config, *, name: str, - primary_key: Optional[Union[str, List[str], List[List[str]]]], - stream_slicer: Optional[StreamSlicer], - request_options_provider: Optional[RequestOptionsProvider] = None, + primary_key: str | list[str] | list[list[str]] | None, + stream_slicer: StreamSlicer | None, + request_options_provider: RequestOptionsProvider | None = None, stop_condition_on_cursor: bool = False, - client_side_incremental_sync: Optional[Dict[str, Any]] = None, - transformations: List[RecordTransformation], + client_side_incremental_sync: dict[str, Any] | None = None, + transformations: list[RecordTransformation], ) -> SimpleRetriever: decoder = ( self._create_component_from_model(model=model.decoder, config=config) @@ -1970,12 +1952,13 @@ def create_async_retriever( config: Config, *, name: str, - primary_key: Optional[ - Union[str, List[str], List[List[str]]] - ], # this seems to be needed to match create_simple_retriever - stream_slicer: Optional[StreamSlicer], - client_side_incremental_sync: Optional[Dict[str, Any]] = None, - transformations: List[RecordTransformation], + primary_key: str + | list[str] + | list[list[str]] + | None, # this seems to be needed to match create_simple_retriever + stream_slicer: StreamSlicer | None, + client_side_incremental_sync: dict[str, Any] | None = None, + transformations: list[RecordTransformation], **kwargs: Any, ) -> AsyncRetriever: decoder = ( diff --git a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py index 8718004b..6278cb3a 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py +++ b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py @@ -1,13 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import itertools import logging from collections import ChainMap -from collections.abc import Callable +from collections.abc import Callable, Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( @@ -19,8 +20,7 @@ def check_for_substream_in_slicers( slicers: Iterable[PartitionRouter], log_warning: Callable[[str], None] ) -> None: - """ - Recursively checks for the presence of SubstreamPartitionRouter within slicers. + """Recursively checks for the presence of SubstreamPartitionRouter within slicers. 
Logs a warning if a SubstreamPartitionRouter is found within a CartesianProductStreamSlicer. Args: @@ -31,15 +31,14 @@ def check_for_substream_in_slicers( if isinstance(slicer, SubstreamPartitionRouter): log_warning("Parent state handling is not supported for CartesianProductStreamSlicer.") return - elif isinstance(slicer, CartesianProductStreamSlicer): + if isinstance(slicer, CartesianProductStreamSlicer): # Recursively check sub-slicers within CartesianProductStreamSlicer check_for_substream_in_slicers(slicer.stream_slicers, log_warning) @dataclass class CartesianProductStreamSlicer(PartitionRouter): - """ - Stream slicers that iterates over the cartesian product of input stream slicers + """Stream slicers that iterates over the cartesian product of input stream slicers Given 2 stream slicers with the following slices: A: [{"i": 0}, {"i": 1}, {"i": 2}] B: [{"s": "hello"}, {"s": "world"}] @@ -57,7 +56,7 @@ class CartesianProductStreamSlicer(PartitionRouter): stream_slicers (List[PartitionRouter]): Underlying stream slicers. The RequestOptions (e.g: Request headers, parameters, etc..) returned by this slicer are the combination of the RequestOptions of its input slicers. If there are conflicts e.g: two slicers define the same header or request param, the conflict is resolved by taking the value from the first slicer, where ordering is determined by the order in which slicers were input to this composite slicer. """ - stream_slicers: List[PartitionRouter] + stream_slicers: list[PartitionRouter] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -66,9 +65,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( ChainMap( @@ -86,9 +85,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( ChainMap( @@ -106,9 +105,9 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( ChainMap( @@ -126,9 +125,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( ChainMap( @@ -160,15 +159,11 @@ def stream_slices(self) -> Iterable[StreamSlice]: yield StreamSlice(partition=partition, cursor_slice=cursor_slice) def set_initial_state(self, stream_state: StreamState) -> None: - """ - Parent stream states are not supported 
for cartesian product stream slicer - """ + """Parent stream states are not supported for cartesian product stream slicer""" pass - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: - """ - Parent stream states are not supported for cartesian product stream slicer - """ + def get_stream_state(self) -> Mapping[str, StreamState] | None: + """Parent stream states are not supported for cartesian product stream slicer""" pass @property diff --git a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py index 29b700b0..d701d6b6 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter @@ -16,8 +18,7 @@ @dataclass class ListPartitionRouter(PartitionRouter): - """ - Partition router that iterates over the values of a list + """Partition router that iterates over the values of a list If values is a string, then evaluate it as literal and assert the resulting literal is a list Attributes: @@ -27,11 +28,11 @@ class ListPartitionRouter(PartitionRouter): request_option (Optional[RequestOption]): The request option to configure the HTTP request """ - values: Union[str, List[str]] - cursor_field: Union[InterpolatedString, str] + values: str | list[str] + cursor_field: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - request_option: Optional[RequestOption] = None + request_option: RequestOption | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.values, str): @@ -48,36 +49,36 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.header, stream_slice) def get_request_body_data( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: 
StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_json, stream_slice) @@ -91,7 +92,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: ] def _get_request_option( - self, request_option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, request_option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: if ( self.request_option @@ -101,19 +102,13 @@ def _get_request_option( slice_value = stream_slice.get(self._cursor_field.eval(self.config)) if slice_value: return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString - else: - return {} - else: return {} + return {} def set_initial_state(self, stream_state: StreamState) -> None: - """ - ListPartitionRouter doesn't have parent streams - """ + """ListPartitionRouter doesn't have parent streams""" pass - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: - """ - ListPartitionRouter doesn't have parent streams - """ + def get_stream_state(self) -> Mapping[str, StreamState] | None: + """ListPartitionRouter doesn't have parent streams""" pass diff --git a/airbyte_cdk/sources/declarative/partition_routers/partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/partition_router.py index 3a9bc3ab..48373eb4 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/partition_router.py @@ -1,10 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Mapping, Optional from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer from airbyte_cdk.sources.types import StreamState @@ -12,8 +13,8 @@ @dataclass class PartitionRouter(StreamSlicer): - """ - Base class for partition routers. + """Base class for partition routers. + Methods: set_parent_state(stream_state): Set the state of the parent streams. get_parent_state(): Get the state of the parent streams. @@ -21,8 +22,7 @@ class PartitionRouter(StreamSlicer): @abstractmethod def set_initial_state(self, stream_state: StreamState) -> None: - """ - Set the state of the parent streams. + """Set the state of the parent streams. This method should only be implemented if the slicer is based on some parent stream and needs to read this stream incrementally using the state. @@ -30,7 +30,8 @@ def set_initial_state(self, stream_state: StreamState) -> None: Args: stream_state (StreamState): The state of the streams to be set. 
The expected format is a dictionary that includes 'parent_state' which is a dictionary of parent state names to their corresponding state. - Example: + + Example: { "parent_state": { "parent_stream_name_1": { ... }, @@ -41,9 +42,8 @@ def set_initial_state(self, stream_state: StreamState) -> None: """ @abstractmethod - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: - """ - Get the state of the parent streams. + def get_stream_state(self) -> Mapping[str, StreamState] | None: + """Get the state of the parent streams. This method should only be implemented if the slicer is based on some parent stream and needs to read this stream incrementally using the state. diff --git a/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py index 32e6a353..d30a7655 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import StreamSlice, StreamState @@ -17,33 +19,33 @@ class SinglePartitionRouter(PartitionRouter): def get_request_params( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_body_data( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_body_json( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} @@ -51,13 +53,9 @@ def stream_slices(self) -> Iterable[StreamSlice]: yield StreamSlice(partition={}, cursor_slice={}) def set_initial_state(self, stream_state: StreamState) -> None: - """ - SinglePartitionRouter doesn't have parent streams - """ + """SinglePartitionRouter doesn't have parent streams""" pass - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: - """ - SinglePartitionRouter doesn't have parent streams - """ + def get_stream_state(self) -> Mapping[str, StreamState] | None: + """SinglePartitionRouter doesn't have parent streams""" pass diff --git 
a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index 4c761d08..c3d7b7cd 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -1,12 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import copy import logging +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any import dpath + from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -18,14 +22,14 @@ from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException + if TYPE_CHECKING: from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream @dataclass class ParentStreamConfig: - """ - Describes how to create a stream slice from a parent stream + """Describes how to create a stream slice from a parent stream stream: The stream to read records from parent_key: The key of the parent stream's records that will be the stream slice key @@ -35,15 +39,15 @@ class ParentStreamConfig: incremental_dependency (bool): Indicates if the parent stream should be read incrementally. """ - stream: "DeclarativeStream" # Parent streams must be DeclarativeStream because we can't know which part of the stream slice is a partition for regular Stream - parent_key: Union[InterpolatedString, str] - partition_field: Union[InterpolatedString, str] + stream: DeclarativeStream # Parent streams must be DeclarativeStream because we can't know which part of the stream slice is a partition for regular Stream + parent_key: InterpolatedString | str + partition_field: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - extra_fields: Optional[Union[List[List[str]], List[List[InterpolatedString]]]] = ( + extra_fields: list[list[str]] | list[list[InterpolatedString]] | None = ( None # List of field paths (arrays of strings) ) - request_option: Optional[RequestOption] = None + request_option: RequestOption | None = None incremental_dependency: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -61,15 +65,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: @dataclass class SubstreamPartitionRouter(PartitionRouter): - """ - Partition router that iterates over the parent's stream records and emits slices + """Partition router that iterates over the parent's stream records and emits slices Will populate the state with `partition_field` and `parent_slice` so they can be accessed by other components Attributes: parent_stream_configs (List[ParentStreamConfig]): parent streams to iterate over and their config """ - parent_stream_configs: List[ParentStreamConfig] + parent_stream_configs: list[ParentStreamConfig] config: Config parameters: InitVar[Mapping[str, Any]] @@ -80,42 +83,42 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: 
StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.header, stream_slice) def get_request_body_data( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_json, stream_slice) def _get_request_option( - self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: params = {} if stream_slice: @@ -137,8 +140,7 @@ def _get_request_option( return params def stream_slices(self) -> Iterable[StreamSlice]: - """ - Iterate over each parent stream's record and create a StreamSlice for each record. + """Iterate over each parent stream's record and create a StreamSlice for each record. For each stream, iterate over its stream_slices. For each stream slice, iterate over each record. @@ -210,10 +212,9 @@ def stream_slices(self) -> Iterable[StreamSlice]: def _extract_extra_fields( self, parent_record: Mapping[str, Any] | AirbyteMessage, - extra_fields: Optional[List[List[str]]] = None, + extra_fields: list[list[str]] | None = None, ) -> Mapping[str, Any]: - """ - Extracts additional fields specified by their paths from the parent record. + """Extracts additional fields specified by their paths from the parent record. Args: parent_record (Mapping[str, Any]): The record from the parent stream to extract fields from. @@ -238,8 +239,7 @@ def _extract_extra_fields( return extracted_extra_fields def set_initial_state(self, stream_state: StreamState) -> None: - """ - Set the state of the parent streams. + """Set the state of the parent streams. If the `parent_state` key is missing from `stream_state`, migrate the child stream state to the parent stream's state format. This migration applies only to parent streams with incremental dependencies. 
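To make the slicing described above concrete (one slice per parent record, with the value under parent_key exposed under partition_field and any extra_fields paths copied alongside it), here is a tiny sketch. The record shapes and field names are invented for illustration; this is not the router's actual code.

    # Tiny sketch: one partition per parent record (illustrative shapes and names).
    from collections.abc import Iterable, Mapping
    from typing import Any

    parent_records = [
        {"id": "c1", "name": "Acme", "owner": {"email": "a@example.com"}},
        {"id": "c2", "name": "Globex", "owner": {"email": "b@example.com"}},
    ]

    PARENT_KEY = "id"
    PARTITION_FIELD = "company_id"
    EXTRA_FIELDS = [["owner", "email"]]  # field paths copied onto each slice

    def build_partitions(records: Iterable[Mapping[str, Any]]) -> Iterable[Mapping[str, Any]]:
        for record in records:
            extras: dict[str, Any] = {}
            for path in EXTRA_FIELDS:
                value: Any = record
                for key in path:
                    value = value[key]
                extras[".".join(path)] = value
            yield {
                "partition": {PARTITION_FIELD: record[PARENT_KEY], "parent_slice": {}},
                "extra_fields": extras,
            }

    for partition in build_partitions(parent_records):
        print(partition)
    # {'partition': {'company_id': 'c1', 'parent_slice': {}}, 'extra_fields': {'owner.email': 'a@example.com'}}
    # {'partition': {'company_id': 'c2', 'parent_slice': {}}, 'extra_fields': {'owner.email': 'b@example.com'}}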
@@ -306,9 +306,8 @@ def set_initial_state(self, stream_state: StreamState) -> None: if parent_config.incremental_dependency: parent_config.stream.state = parent_state.get(parent_config.stream.name, {}) - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: - """ - Get the state of the parent streams. + def get_stream_state(self) -> Mapping[str, StreamState] | None: + """Get the state of the parent streams. Returns: StreamState: The current state of the parent streams. diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py index d9213eb9..a3574860 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy from airbyte_cdk.sources.types import Config @@ -13,14 +16,13 @@ @dataclass class ConstantBackoffStrategy(BackoffStrategy): - """ - Backoff strategy with a constant backoff interval + """Backoff strategy with a constant backoff interval Attributes: backoff_time_in_seconds (float): time to backoff before retrying a retryable request. """ - backoff_time_in_seconds: Union[float, InterpolatedString, str] + backoff_time_in_seconds: float | InterpolatedString | str parameters: InitVar[Mapping[str, Any]] config: Config @@ -38,7 +40,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: return self.backoff_time_in_seconds.eval(self.config) # type: ignore # backoff_time_in_seconds is always cast to an interpolated string diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py index b3a57675..80cc0aa4 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy from airbyte_cdk.sources.types import Config @@ -13,8 +16,7 @@ @dataclass class ExponentialBackoffStrategy(BackoffStrategy): - """ - Backoff strategy with an exponential backoff interval + """Backoff strategy with an exponential backoff interval Attributes: factor (float): multiplicative factor @@ -22,7 +24,7 @@ class ExponentialBackoffStrategy(BackoffStrategy): parameters: InitVar[Mapping[str, Any]] config: Config - factor: Union[float, InterpolatedString, str] = 5 + factor: float | InterpolatedString | str = 5 def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not isinstance(self.factor, InterpolatedString): @@ -38,7 +40,7 @@ def _retry_factor(self) -> float: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: return self._retry_factor * 2**attempt_count # type: ignore # factor is always cast to an interpolated string diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py index 60103f34..1061b6b3 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py @@ -1,19 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import numbers from re import Pattern -from typing import Optional import requests def get_numeric_value_from_header( - response: requests.Response, header: str, regex: Optional[Pattern[str]] -) -> Optional[float]: - """ - Extract a header value from the response as a float + response: requests.Response, header: str, regex: Pattern[str] | None +) -> float | None: + """Extract a header value from the response as a float :param response: response the extract header value from :param header: Header to extract :param regex: optional regex to apply on the header to obtain the value @@ -28,13 +27,12 @@ def get_numeric_value_from_header( if match: header_value = match.group() return _as_float(header_value) - elif isinstance(header_value, numbers.Number): + if isinstance(header_value, numbers.Number): return float(header_value) # type: ignore[arg-type] - else: - return None + return None -def _as_float(s: str) -> Optional[float]: +def _as_float(s: str) -> float | None: try: return float(s) except ValueError: diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py index 7672bd82..9ed0ddd1 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import re +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import ( @@ -21,8 +24,7 @@ @dataclass class WaitTimeFromHeaderBackoffStrategy(BackoffStrategy): - """ - Extract wait time from http header + """Extract wait time from http header Attributes: header (str): header to read wait time from @@ -30,11 +32,11 @@ class WaitTimeFromHeaderBackoffStrategy(BackoffStrategy): max_waiting_time_in_seconds: (Optional[float]): given the value extracted from the header is greater than this value, stop the stream """ - header: Union[InterpolatedString, str] + header: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] config: Config - regex: Optional[Union[InterpolatedString, str]] = None - max_waiting_time_in_seconds: Optional[float] = None + regex: InterpolatedString | str | None = None + max_waiting_time_in_seconds: float | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.regex = ( @@ -44,9 +46,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: header = self.header.eval(config=self.config) # type: ignore # header is always cast to an interpolated stream if self.regex: evaled_regex = self.regex.eval(self.config) # type: ignore # header is always cast to an 
interpolated string diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py index 4aed7338..f2a90d74 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py @@ -1,14 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import numbers import re import time +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import ( get_numeric_value_from_header, @@ -21,8 +24,7 @@ @dataclass class WaitUntilTimeFromHeaderBackoffStrategy(BackoffStrategy): - """ - Extract time at which we can retry the request from response header + """Extract time at which we can retry the request from response header and wait for the difference between now and that time Attributes: @@ -31,11 +33,11 @@ class WaitUntilTimeFromHeaderBackoffStrategy(BackoffStrategy): regex (Optional[str]): optional regex to apply on the header to extract its value """ - header: Union[InterpolatedString, str] + header: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] config: Config - min_wait: Optional[Union[float, InterpolatedString, str]] = None - regex: Optional[Union[InterpolatedString, str]] = None + min_wait: float | InterpolatedString | str | None = None + regex: InterpolatedString | str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.header = InterpolatedString.create(self.header, parameters=parameters) @@ -47,9 +49,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: now = time.time() header = self.header.eval(self.config) # type: ignore # header is always cast to an interpolated string if self.regex: @@ -71,6 +73,6 @@ def backoff_time( return float(min_wait) if min_wait: return float(max(wait_time, min_wait)) - elif wait_time < 0: + if wait_time < 0: return None return wait_time diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategy.py index 7a44f7b9..14bd21be 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategy.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC from dataclasses import dataclass @@ -10,8 +11,7 @@ @dataclass class DecalarativeBackoffStrategy(BackoffStrategy, ABC): - """ - This interface exists to retain backwards compatability with connectors that reference the declarative BackoffStrategy. 
As part of the effort to promote common interfaces to the Python CDK, this now extends the Python CDK backoff strategy interface. + """This interface exists to retain backwards compatability with connectors that reference the declarative BackoffStrategy. As part of the effort to promote common interfaces to the Python CDK, this now extends the Python CDK backoff strategy interface. Backoff strategy defining how long to wait before retrying a request that resulted in an error. """ diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py index 717fcba6..37bb5e97 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, List, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( ErrorResolution, @@ -16,8 +19,7 @@ @dataclass class CompositeErrorHandler(ErrorHandler): - """ - Error handler that sequentially iterates over a list of `ErrorHandler`s + """Error handler that sequentially iterates over a list of `ErrorHandler`s Sample config chaining 2 different retriers: error_handler: @@ -39,7 +41,7 @@ class CompositeErrorHandler(ErrorHandler): error_handlers (List[ErrorHandler]): list of error handlers """ - error_handlers: List[ErrorHandler] + error_handlers: list[ErrorHandler] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -47,15 +49,15 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: raise ValueError("CompositeErrorHandler expects at least 1 underlying error handler") @property - def max_retries(self) -> Optional[int]: + def max_retries(self) -> int | None: return self.error_handlers[0].max_retries @property - def max_time(self) -> Optional[int]: + def max_time(self) -> int | None: return max([error_handler.max_time or 0 for error_handler in self.error_handlers]) def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] + self, response_or_exception: requests.Response | Exception | None ) -> ErrorResolution: matched_error_resolution = None for error_handler in self.error_handlers: diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py index ad4a6261..ca9cfa83 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, MutableMapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.requesters.error_handlers.default_http_response_filter import ( DefaultHttpResponseFilter, ) @@ -23,8 +26,7 @@ @dataclass class DefaultErrorHandler(ErrorHandler): - """ - Default error handler. + """Default error handler. By default, the handler will only use the `DEFAULT_ERROR_MAPPING` that is part of the Python CDK's `HttpStatusErrorHandler`. @@ -94,12 +96,12 @@ class DefaultErrorHandler(ErrorHandler): parameters: InitVar[Mapping[str, Any]] config: Config - response_filters: Optional[List[HttpResponseFilter]] = None - max_retries: Optional[int] = 5 + response_filters: list[HttpResponseFilter] | None = None + max_retries: int | None = 5 max_time: int = 60 * 10 _max_retries: int = field(init=False, repr=False, default=5) _max_time: int = field(init=False, repr=False, default=60 * 10) - backoff_strategies: Optional[List[BackoffStrategy]] = None + backoff_strategies: list[BackoffStrategy] | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not self.response_filters: @@ -108,7 +110,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._last_request_to_attempt_count: MutableMapping[requests.PreparedRequest, int] = {} def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] + self, response_or_exception: requests.Response | Exception | None ) -> ErrorResolution: if self.response_filters: for response_filter in self.response_filters: @@ -124,17 +126,15 @@ def interpret_response( default_reponse_filter = DefaultHttpResponseFilter(parameters={}, config=self.config) default_response_filter_resolution = default_reponse_filter.matches(response_or_exception) - return ( - default_response_filter_resolution - if default_response_filter_resolution - else create_fallback_error_resolution(response_or_exception) + return default_response_filter_resolution or create_fallback_error_resolution( + response_or_exception ) def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int = 0, - ) -> Optional[float]: + ) -> float | None: backoff = None if self.backoff_strategies: for backoff_strategy in self.backoff_strategies: diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py index 395df5c9..61cbe3f9 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py @@ -1,10 +1,10 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
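The DefaultErrorHandler hunk above also rewrites its final expression to `resolution or create_fallback_error_resolution(...)`. The resolution order it implements is: user-configured response filters first, then the built-in DefaultHttpResponseFilter, then a generic fallback. A condensed sketch of that order; the filter loop body is truncated in the hunk, so the early return on a match is assumed.

from __future__ import annotations

import requests


def resolve_error(
    response_or_exception: requests.Response | Exception | None,
    response_filters: list,   # objects exposing .matches(response_or_exception)
    default_filter,           # stand-in for DefaultHttpResponseFilter
    fallback,                 # stand-in for create_fallback_error_resolution
):
    """Illustrative resolution order: explicit filters, then the default mapping, then a fallback."""
    for response_filter in response_filters:
        resolution = response_filter.matches(response_or_exception)
        if resolution is not None:
            return resolution
    return default_filter.matches(response_or_exception) or fallback(response_or_exception)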
# - -from typing import Optional, Union +from __future__ import annotations import requests + from airbyte_cdk.sources.declarative.requesters.error_handlers.http_response_filter import ( HttpResponseFilter, ) @@ -19,12 +19,12 @@ class DefaultHttpResponseFilter(HttpResponseFilter): def matches( - self, response_or_exception: Optional[Union[requests.Response, Exception]] - ) -> Optional[ErrorResolution]: + self, response_or_exception: requests.Response | Exception | None + ) -> ErrorResolution | None: default_mapped_error_resolution = None if isinstance(response_or_exception, (requests.Response, Exception)): - mapped_key: Union[int, type] = ( + mapped_key: int | type = ( response_or_exception.status_code if isinstance(response_or_exception, requests.Response) else response_or_exception.__class__ @@ -32,8 +32,6 @@ def matches( default_mapped_error_resolution = DEFAULT_ERROR_MAPPING.get(mapped_key) - return ( - default_mapped_error_resolution - if default_mapped_error_resolution - else create_fallback_error_resolution(response_or_exception) + return default_mapped_error_resolution or create_fallback_error_resolution( + response_or_exception ) diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/error_handler.py index a84747f9..a7feb82f 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/error_handler.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC from dataclasses import dataclass @@ -10,8 +11,7 @@ @dataclass class DeclarativeErrorHandler(ErrorHandler, ABC): - """ - This interface exists to retain backwards compatability with connectors that reference the declarative ErrorHandler. As part of the effort to promote common interfaces to the Python CDK, this now extends the Python CDK ErrorHandler interface. + """This interface exists to retain backwards compatability with connectors that reference the declarative ErrorHandler. As part of the effort to promote common interfaces to the Python CDK, this now extends the Python CDK ErrorHandler interface. `ErrorHandler` defines how to handle errors that occur during the request process, returning an ErrorResolution object that defines how to proceed. """ diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py index 366ad687..39727841 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Set, Union +from typing import Any import requests + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean @@ -22,8 +25,7 @@ @dataclass class HttpResponseFilter: - """ - Filter to select a response based on its HTTP status code, error message or a predicate. 
+ """Filter to select a response based on its HTTP status code, error message or a predicate. If a response matches the filter, the response action, failure_type, and error message are returned as an ErrorResolution object. For http_codes declared in the filter, the failure_type will default to `system_error`. To override default failure_type use configured failure_type with ResponseAction.FAIL. @@ -39,12 +41,12 @@ class HttpResponseFilter: config: Config parameters: InitVar[Mapping[str, Any]] - action: Optional[Union[ResponseAction, str]] = None - failure_type: Optional[Union[FailureType, str]] = None - http_codes: Optional[Set[int]] = None - error_message_contains: Optional[str] = None - predicate: Union[InterpolatedBoolean, str] = "" - error_message: Union[InterpolatedString, str] = "" + action: ResponseAction | str | None = None + failure_type: FailureType | str | None = None + http_codes: set[int] | None = None + error_message_contains: str | None = None + predicate: InterpolatedBoolean | str = "" + error_message: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.action is not None: @@ -56,7 +58,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: raise ValueError( "HttpResponseFilter requires a filter condition if an action is specified" ) - elif isinstance(self.action, str): + if isinstance(self.action, str): self.action = ResponseAction[self.action] self.http_codes = self.http_codes or set() if isinstance(self.predicate, str): @@ -69,8 +71,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.failure_type = FailureType[self.failure_type] def matches( - self, response_or_exception: Optional[Union[requests.Response, Exception]] - ) -> Optional[ErrorResolution]: + self, response_or_exception: requests.Response | Exception | None + ) -> ErrorResolution | None: filter_action = self._matches_filter(response_or_exception) mapped_key = ( response_or_exception.status_code @@ -117,15 +119,14 @@ def matches( return None def _match_default_error_mapping( - self, mapped_key: Union[int, type[Exception]] - ) -> Optional[ErrorResolution]: + self, mapped_key: int | type[Exception] + ) -> ErrorResolution | None: return DEFAULT_ERROR_MAPPING.get(mapped_key) def _matches_filter( - self, response_or_exception: Optional[Union[requests.Response, Exception]] - ) -> Optional[ResponseAction]: - """ - Apply the HTTP filter on the response and return the action to execute if it matches + self, response_or_exception: requests.Response | Exception | None + ) -> ResponseAction | None: + """Apply the HTTP filter on the response and return the action to execute if it matches :param response: The HTTP response to evaluate :return: The action to execute. None if the response does not match the filter """ @@ -144,9 +145,8 @@ def _safe_response_json(response: requests.Response) -> dict[str, Any]: except requests.exceptions.JSONDecodeError: return {} - def _create_error_message(self, response: requests.Response) -> Optional[str]: - """ - Construct an error message based on the specified message template of the filter. + def _create_error_message(self, response: requests.Response) -> str | None: + """Construct an error message based on the specified message template of the filter. 
:param response: The HTTP response which can be used during interpolation :return: The evaluated error message string to be emitted """ @@ -169,8 +169,5 @@ def _response_matches_predicate(self, response: requests.Response) -> bool: def _response_contains_error_message(self, response: requests.Response) -> bool: if not self.error_message_contains: return False - else: - error_message = self._error_message_parser.parse_response_error_message( - response=response - ) - return bool(error_message and self.error_message_contains in error_message) + error_message = self._error_message_parser.parse_response_error_message(response=response) + return bool(error_message and self.error_message_contains in error_message) diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index ff213068..4c604d1f 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -1,11 +1,16 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations + import logging import uuid +from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Any import requests +from requests import Response + from airbyte_cdk import AirbyteMessage from airbyte_cdk.logger import lazy_log from airbyte_cdk.models import FailureType, Type @@ -23,7 +28,7 @@ from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.utils import AirbyteTracedException -from requests import Response + LOGGER = logging.getLogger("airbyte") @@ -33,24 +38,23 @@ class AsyncHttpJobRepository(AsyncJobRepository): creation_requester: Requester polling_requester: Requester download_retriever: SimpleRetriever - abort_requester: Optional[Requester] - delete_requester: Optional[Requester] + abort_requester: Requester | None + delete_requester: Requester | None status_extractor: DpathExtractor status_mapping: Mapping[str, AsyncJobStatus] urls_extractor: DpathExtractor - job_timeout: Optional[timedelta] = None + job_timeout: timedelta | None = None record_extractor: RecordExtractor = field( init=False, repr=False, default_factory=lambda: ResponseToFileExtractor() ) def __post_init__(self) -> None: - self._create_job_response_by_id: Dict[str, Response] = {} - self._polling_job_response_by_id: Dict[str, Response] = {} + self._create_job_response_by_id: dict[str, Response] = {} + self._polling_job_response_by_id: dict[str, Response] = {} def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests.Response: - """ - Validates and retrieves the pooling response for a given stream slice. + """Validates and retrieves the pooling response for a given stream slice. Args: stream_slice (StreamSlice): The stream slice to send the pooling request for. @@ -61,8 +65,7 @@ def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests Raises: AirbyteTracedException: If the polling request returns an empty response. 
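HttpResponseFilter, patched above, selects a response via any of: a set of HTTP status codes, an interpolated predicate evaluated against the JSON body, or a substring check on the parsed error message, and then maps the match to a ResponseAction. A standalone sketch of that matching, with callables standing in for the interpolated pieces and a plain substring check standing in for the error message parser:

from __future__ import annotations

import requests


def filter_matches(
    response: requests.Response,
    http_codes: set[int] | None = None,
    predicate=None,                       # callable(decoded_json) -> bool, stand-in for InterpolatedBoolean
    error_message_contains: str | None = None,
) -> bool:
    """Illustrative: any configured condition that matches selects this filter."""
    if http_codes and response.status_code in http_codes:
        return True
    if predicate is not None:
        try:
            body = response.json()
        except requests.exceptions.JSONDecodeError:
            body = {}
        if predicate(body):
            return True
    if error_message_contains and error_message_contains in (response.text or ""):
        return True
    return False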
""" - - polling_response: Optional[requests.Response] = self.polling_requester.send_request( + polling_response: requests.Response | None = self.polling_requester.send_request( stream_slice=stream_slice ) if polling_response is None: @@ -73,8 +76,7 @@ def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests return polling_response def _get_validated_job_status(self, response: requests.Response) -> AsyncJobStatus: - """ - Validates the job status extracted from the API response. + """Validates the job status extracted from the API response. Args: response (requests.Response): The API response. @@ -85,7 +87,6 @@ def _get_validated_job_status(self, response: requests.Response) -> AsyncJobStat Raises: ValueError: If the API status is unknown. """ - api_status = next(iter(self.status_extractor.extract_records(response)), None) job_status = self.status_mapping.get(str(api_status), None) if job_status is None: @@ -96,8 +97,7 @@ def _get_validated_job_status(self, response: requests.Response) -> AsyncJobStat return job_status def _start_job_and_validate_response(self, stream_slice: StreamSlice) -> requests.Response: - """ - Starts a job and validates the response. + """Starts a job and validates the response. Args: stream_slice (StreamSlice): The stream slice to be used for the job. @@ -108,8 +108,7 @@ def _start_job_and_validate_response(self, stream_slice: StreamSlice) -> request Raises: AirbyteTracedException: If no response is received from the creation requester. """ - - response: Optional[requests.Response] = self.creation_requester.send_request( + response: requests.Response | None = self.creation_requester.send_request( stream_slice=stream_slice ) if not response: @@ -121,8 +120,7 @@ def _start_job_and_validate_response(self, stream_slice: StreamSlice) -> request return response def start(self, stream_slice: StreamSlice) -> AsyncJob: - """ - Starts a job for the given stream slice. + """Starts a job for the given stream slice. Args: stream_slice (StreamSlice): The stream slice to start the job for. @@ -130,7 +128,6 @@ def start(self, stream_slice: StreamSlice) -> AsyncJob: Returns: AsyncJob: The asynchronous job object representing the started job. """ - response: requests.Response = self._start_job_and_validate_response(stream_slice) job_id: str = str(uuid.uuid4()) self._create_job_response_by_id[job_id] = response @@ -138,8 +135,7 @@ def start(self, stream_slice: StreamSlice) -> AsyncJob: return AsyncJob(api_job_id=job_id, job_parameters=stream_slice, timeout=self.job_timeout) def update_jobs_status(self, jobs: Iterable[AsyncJob]) -> None: - """ - Updates the status of multiple jobs. + """Updates the status of multiple jobs. Because we don't have interpolation on random fields, we have this hack which consist on using the stream_slice to allow for interpolation. We are looking at enabling interpolation on more field which would require a change to those three layers: @@ -174,8 +170,7 @@ def update_jobs_status(self, jobs: Iterable[AsyncJob]) -> None: self._polling_job_response_by_id[job.api_job_id()] = polling_response def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: - """ - Fetches records from the given job. + """Fetches records from the given job. Args: job (AsyncJob): The job to fetch records from. @@ -184,7 +179,6 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: Iterable[Mapping[str, Any]]: A generator that yields records as dictionaries. 
""" - for url in self.urls_extractor.extract_records( self._polling_job_response_by_id[job.api_job_id()] ): diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index 51ece9f9..c1bf2cc4 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -1,14 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import os +from collections.abc import Callable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Callable, Mapping, MutableMapping, Optional, Union +from typing import Any from urllib.parse import urljoin import requests + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import ( DeclarativeAuthenticator, NoAuth, @@ -29,8 +32,7 @@ @dataclass class HttpRequester(Requester): - """ - Default implementation of a Requester + """Default implementation of a Requester Attributes: name (str): Name of the stream. Only used for request/response caching @@ -46,14 +48,14 @@ class HttpRequester(Requester): """ name: str - url_base: Union[InterpolatedString, str] - path: Union[InterpolatedString, str] + url_base: InterpolatedString | str + path: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - authenticator: Optional[DeclarativeAuthenticator] = None - http_method: Union[str, HttpMethod] = HttpMethod.GET - request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None - error_handler: Optional[ErrorHandler] = None + authenticator: DeclarativeAuthenticator | None = None + http_method: str | HttpMethod = HttpMethod.GET + request_options_provider: InterpolatedRequestOptionsProvider | None = None + error_handler: ErrorHandler | None = None disable_retries: bool = False message_repository: MessageRepository = NoopMessageRepository() use_cache: bool = False @@ -114,9 +116,9 @@ def get_url_base(self) -> str: def get_path( self, *, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, ) -> str: kwargs = { "stream_state": stream_state, @@ -132,9 +134,9 @@ def get_method(self) -> HttpMethod: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return self._request_options_provider.get_request_params( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -143,9 +145,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._request_options_provider.get_request_headers( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -155,10 +157,10 @@ def get_request_headers( def get_request_body_data( # type: ignore 
self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return ( self._request_options_provider.get_request_body_data( stream_state=stream_state, @@ -172,10 +174,10 @@ def get_request_body_data( # type: ignore def get_request_body_json( # type: ignore self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: return self._request_options_provider.get_request_body_json( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token ) @@ -186,15 +188,14 @@ def logger(self) -> logging.Logger: def _get_request_options( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - requester_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - auth_options_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - extra_options: Optional[Union[Mapping[str, Any], str]] = None, - ) -> Union[Mapping[str, Any], str]: - """ - Get the request_option from the requester, the authenticator and extra_options passed in. + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + requester_method: Callable[..., Mapping[str, Any] | str | None], + auth_options_method: Callable[..., Mapping[str, Any] | str | None], + extra_options: Mapping[str, Any] | str | None = None, + ) -> Mapping[str, Any] | str: + """Get the request_option from the requester, the authenticator and extra_options passed in. Raise a ValueError if there's a key collision Returned merged mapping otherwise """ @@ -212,13 +213,12 @@ def _get_request_options( def _request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - extra_headers: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + extra_headers: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies request headers. + """Specifies request headers. Authentication headers will overwrite any overlapping headers returned from this method. """ headers = self._get_request_options( @@ -235,13 +235,12 @@ def _request_headers( def _request_params( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - extra_params: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + extra_params: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. + """Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. 
E.g: you might want to define query parameters for paging if next_page_token is not None. """ @@ -266,13 +265,12 @@ def _request_params( def _request_body_data( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - extra_body_data: Optional[Union[Mapping[str, Any], str]] = None, - ) -> Optional[Union[Mapping[str, Any], str]]: - """ - Specifies how to populate the body of the request with a non-JSON payload. + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + extra_body_data: Mapping[str, Any] | str | None = None, + ) -> Mapping[str, Any] | str | None: + """Specifies how to populate the body of the request with a non-JSON payload. If returns a ready text that it will be sent as is. If returns a dict that it will be converted to a urlencoded form. @@ -292,13 +290,12 @@ def _request_body_data( def _request_body_json( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - extra_body_json: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: - """ - Specifies how to populate the body of the request with a JSON payload. + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + extra_body_json: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: + """Specifies how to populate the body of the request with a JSON payload. At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. """ @@ -321,16 +318,16 @@ def _join_url(cls, url_base: str, path: str) -> str: def send_request( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - path: Optional[str] = None, - request_headers: Optional[Mapping[str, Any]] = None, - request_params: Optional[Mapping[str, Any]] = None, - request_body_data: Optional[Union[Mapping[str, Any], str]] = None, - request_body_json: Optional[Mapping[str, Any]] = None, - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - ) -> Optional[requests.Response]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + path: str | None = None, + request_headers: Mapping[str, Any] | None = None, + request_params: Mapping[str, Any] | None = None, + request_body_data: Mapping[str, Any] | str | None = None, + request_body_json: Mapping[str, Any] | None = None, + log_formatter: Callable[[requests.Response], Any] | None = None, + ) -> requests.Response | None: request, response = self._http_client.send_request( http_method=self.get_method().value, url=self._join_url( diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index e26f32de..bb77ab92 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
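HttpRequester._get_request_options, retyped above, merges options coming from the requester itself, the authenticator, and any extra options passed to send_request, and raises on key collisions. A minimal sketch of that merge rule; this helper is illustrative, not the CDK's own utility.

from __future__ import annotations

from collections.abc import Mapping
from typing import Any


def merge_request_options(*mappings: Mapping[str, Any] | None) -> dict[str, Any]:
    """Merge option mappings from several providers, refusing silent overwrites on duplicate keys."""
    combined: dict[str, Any] = {}
    for mapping in mappings:
        for key, value in (mapping or {}).items():
            if key in combined:
                raise ValueError(f"Duplicate key {key!r} found when merging request options")
            combined[key] = value
    return combined


# e.g. headers = merge_request_options(requester_headers, auth_headers, extra_headers)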
# +from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.decoders import ( Decoder, JsonDecoder, @@ -26,8 +29,7 @@ @dataclass class DefaultPaginator(Paginator): - """ - Default paginator to request pages of results with a fixed size until the pagination strategy no longer returns a next_page_token + """Default paginator to request pages of results with a fixed size until the pagination strategy no longer returns a next_page_token Examples: 1. @@ -96,13 +98,13 @@ class DefaultPaginator(Paginator): pagination_strategy: PaginationStrategy config: Config - url_base: Union[InterpolatedString, str] + url_base: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] decoder: Decoder = field( default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) ) - page_size_option: Optional[RequestOption] = None - page_token_option: Optional[Union[RequestPath, RequestOption]] = None + page_size_option: RequestOption | None = None + page_token_option: RequestPath | RequestOption | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.page_size_option and not self.pagination_strategy.get_page_size(): @@ -111,20 +113,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) if isinstance(self.url_base, str): self.url_base = InterpolatedString(string=self.url_base, parameters=parameters) - self._token: Optional[Any] = self.pagination_strategy.initial_token + self._token: Any | None = self.pagination_strategy.initial_token def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Mapping[str, Any]]: + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Mapping[str, Any] | None: self._token = self.pagination_strategy.next_page_token( response, last_page_size, last_record ) if self._token: return {"next_page_token": self._token} - else: - return None + return None - def path(self) -> Optional[str]: + def path(self) -> str | None: if ( self._token and self.page_token_option @@ -132,46 +133,45 @@ def path(self) -> Optional[str]: ): # Replace url base to only return the path return str(self._token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__ - else: - return None + return None def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, str]: return self._get_request_options(RequestOptionType.header) def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: 
Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_data) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json) - def reset(self, reset_value: Optional[Any] = None) -> None: + def reset(self, reset_value: Any | None = None) -> None: if reset_value: self.pagination_strategy.reset(reset_value=reset_value) else: @@ -200,8 +200,7 @@ def _get_request_options(self, option_type: RequestOptionType) -> MutableMapping class PaginatorTestReadDecorator(Paginator): - """ - In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of + """In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of pages that are queried throughout a read command. """ @@ -217,23 +216,23 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Mapping[str, Any]]: + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Mapping[str, Any] | None: if self._page_count >= self._maximum_number_of_pages: return None self._page_count += 1 return self._decorated.next_page_token(response, last_page_size, last_record) - def path(self) -> Optional[str]: + def path(self) -> str | None: return self._decorated.path() def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._decorated.get_request_params( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -242,9 +241,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, str]: return self._decorated.get_request_headers( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -253,10 +252,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._decorated.get_request_body_data( stream_state=stream_state, 
stream_slice=stream_slice, next_page_token=next_page_token ) @@ -264,14 +263,14 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._decorated.get_request_body_json( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token ) - def reset(self, reset_value: Optional[Any] = None) -> None: + def reset(self, reset_value: Any | None = None) -> None: self._decorated.reset() self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py index db4eb0ed..ab27d811 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py @@ -1,67 +1,68 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator from airbyte_cdk.sources.types import Record, StreamSlice, StreamState @dataclass class NoPagination(Paginator): - """ - Pagination implementation that never returns a next page. - """ + """Pagination implementation that never returns a next page.""" parameters: InitVar[Mapping[str, Any]] - def path(self) -> Optional[str]: + def path(self) -> str | None: return None def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return {} def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, str]: return {} def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return {} def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, response: requests.Response, last_page_size: int, 
last_record: Record | None ) -> Mapping[str, Any]: return {} - def reset(self, reset_value: Optional[Any] = None) -> None: + def reset(self, reset_value: Any | None = None) -> None: # No state to reset pass diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py index 1bf17d1d..d326405f 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Mapping, Optional +from typing import Any import requests + from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( RequestOptionsProvider, ) @@ -15,25 +18,21 @@ @dataclass class Paginator(ABC, RequestOptionsProvider): - """ - Defines the token to use to fetch the next page of records from the API. + """Defines the token to use to fetch the next page of records from the API. If needed, the Paginator will set request options to be set on the HTTP request to fetch the next page of records. If the next_page_token is the path to the next page of records, then it should be accessed through the `path` method """ @abstractmethod - def reset(self, reset_value: Optional[Any] = None) -> None: - """ - Reset the pagination's inner state - """ + def reset(self, reset_value: Any | None = None) -> None: + """Reset the pagination's inner state""" @abstractmethod def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Mapping[str, Any]]: - """ - Returns the next_page_token to use to fetch the next page of records. + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Mapping[str, Any] | None: + """Returns the next_page_token to use to fetch the next page of records. :param response: the response to process :param last_page_size: the number of records read from the response @@ -43,9 +42,8 @@ def next_page_token( pass @abstractmethod - def path(self) -> Optional[str]: - """ - Returns the URL path to hit to fetch the next page of records + def path(self) -> str | None: + """Returns the URL path to hit to fetch the next page of records e.g: if you wanted to hit https://myapi.com/v1/some_entity then this will return "some_entity" diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py index a53a044b..0c3f04d7 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
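The Paginator interface above is consumed by a retriever loop: send a request, extract records, then ask the paginator for the next token, stopping when it returns nothing (DefaultPaginator wraps the strategy token as {"next_page_token": token}; NoPagination always returns an empty mapping). A typical driving loop, sketched with plain callables:

from __future__ import annotations


def read_all_pages(send_request, extract_records, paginator):
    """Illustrative read loop: fetch a page, emit its records, then ask the paginator for the next token."""
    next_page_token = None
    while True:
        response = send_request(next_page_token=next_page_token)
        records = extract_records(response)
        yield from records
        last_record = records[-1] if records else None
        next_page_token = paginator.next_page_token(response, len(records), last_record)
        if not next_page_token:
            break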
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.decoders import ( Decoder, JsonDecoder, @@ -21,8 +24,7 @@ @dataclass class CursorPaginationStrategy(PaginationStrategy): - """ - Pagination strategy that evaluates an interpolated string to define the next page token + """Pagination strategy that evaluates an interpolated string to define the next page token Attributes: page_size (Optional[int]): the number of records to request @@ -32,11 +34,11 @@ class CursorPaginationStrategy(PaginationStrategy): decoder (Decoder): decoder to decode the response """ - cursor_value: Union[InterpolatedString, str] + cursor_value: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - page_size: Optional[int] = None - stop_condition: Optional[Union[InterpolatedBoolean, str]] = None + page_size: int | None = None + stop_condition: InterpolatedBoolean | str | None = None decoder: Decoder = field( default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) ) @@ -48,24 +50,24 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else: self._cursor_value = self.cursor_value if isinstance(self.stop_condition, str): - self._stop_condition: Optional[InterpolatedBoolean] = InterpolatedBoolean( + self._stop_condition: InterpolatedBoolean | None = InterpolatedBoolean( condition=self.stop_condition, parameters=parameters ) else: self._stop_condition = self.stop_condition @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: return self._initial_cursor def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Any]: + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Any | None: decoded_response = next(self.decoder.decode(response)) # The default way that link is presented in requests.Response is a string of various links (last, next, etc). This # is not indexable or useful for parsing the cursor, so we replace it with the link dictionary from response.links - headers: Dict[str, Any] = dict(response.headers) + headers: dict[str, Any] = dict(response.headers) headers["link"] = response.links if self._stop_condition: should_stop = self._stop_condition.eval( @@ -84,10 +86,10 @@ def next_page_token( last_record=last_record, last_page_size=last_page_size, ) - return token if token else None + return token or None - def reset(self, reset_value: Optional[Any] = None) -> None: + def reset(self, reset_value: Any | None = None) -> None: self._initial_cursor = reset_value - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: return self.page_size diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py index 9f24b961..02b9e154 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
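CursorPaginationStrategy, retyped above, evaluates an interpolated cursor expression against the decoded response, the response headers (with `link` replaced by the parsed response.links dictionary), the last record, and the last page size; an optional stop condition ends pagination early. A standalone sketch with callables standing in for the interpolated expressions:

from __future__ import annotations

from typing import Any

import requests


def next_cursor(
    response: requests.Response,
    decoded_body: dict[str, Any],
    last_record: dict[str, Any] | None,
    last_page_size: int,
    cursor_value,            # callable standing in for the interpolated cursor expression
    stop_condition=None,     # optional callable standing in for InterpolatedBoolean
) -> Any | None:
    headers: dict[str, Any] = dict(response.headers)
    headers["link"] = response.links  # the dict form is easier to index than the raw Link header string
    context = {
        "response": decoded_body,
        "headers": headers,
        "last_record": last_record,
        "last_page_size": last_page_size,
    }
    if stop_condition and stop_condition(**context):
        return None
    return cursor_value(**context) or None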
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.decoders import ( Decoder, JsonDecoder, @@ -20,8 +23,7 @@ @dataclass class OffsetIncrement(PaginationStrategy): - """ - Pagination strategy that returns the number of records reads so far and returns it as the next page token + """Pagination strategy that returns the number of records reads so far and returns it as the next page token Examples: # page_size to be a constant integer value pagination_strategy: @@ -43,7 +45,7 @@ class OffsetIncrement(PaginationStrategy): """ config: Config - page_size: Optional[Union[str, int]] + page_size: str | int | None parameters: InitVar[Mapping[str, Any]] decoder: Decoder = field( default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) @@ -54,21 +56,21 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._offset = 0 page_size = str(self.page_size) if isinstance(self.page_size, int) else self.page_size if page_size: - self._page_size: Optional[InterpolatedString] = InterpolatedString( + self._page_size: InterpolatedString | None = InterpolatedString( page_size, parameters=parameters ) else: self._page_size = None @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: if self.inject_on_first_request: return self._offset return None def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Any]: + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Any | None: decoded_response = next(self.decoder.decode(response)) # Stop paginating when there are fewer records than the page size or the current page has no records @@ -77,23 +79,20 @@ def next_page_token( and last_page_size < self._page_size.eval(self.config, response=decoded_response) ) or last_page_size == 0: return None - else: - self._offset += last_page_size - return self._offset + self._offset += last_page_size + return self._offset - def reset(self, reset_value: Optional[Any] = 0) -> None: + def reset(self, reset_value: Any | None = 0) -> None: if not isinstance(reset_value, int): raise ValueError( f"Reset value {reset_value} for OffsetIncrement pagination strategy was not an integer" ) - else: - self._offset = reset_value + self._offset = reset_value - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: if self._page_size: page_size = self._page_size.eval(self.config) if not isinstance(page_size, int): raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") return page_size - else: - return None + return None diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py index 1ce0a1c8..5f860574 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
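OffsetIncrement, shown above, stops when the last page was empty or shorter than the page size and otherwise advances the offset by the number of records just read. The same behaviour in a compact standalone form:

from __future__ import annotations


class OffsetSketch:
    """Illustrative offset pagination: stop on a short or empty page, else offset += records read."""

    def __init__(self, page_size: int | None = None) -> None:
        self.page_size = page_size
        self.offset = 0

    def next_page_token(self, last_page_size: int) -> int | None:
        if last_page_size == 0 or (self.page_size and last_page_size < self.page_size):
            return None
        self.offset += last_page_size
        return self.offset

    def reset(self, reset_value: int = 0) -> None:
        self.offset = reset_value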
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( PaginationStrategy, @@ -15,8 +18,7 @@ @dataclass class PageIncrement(PaginationStrategy): - """ - Pagination strategy that returns the number of pages reads so far and returns it as the next page token + """Pagination strategy that returns the number of pages reads so far and returns it as the next page token Attributes: page_size (int): the number of records to request @@ -24,7 +26,7 @@ class PageIncrement(PaginationStrategy): """ config: Config - page_size: Optional[Union[str, int]] + page_size: str | int | None parameters: InitVar[Mapping[str, Any]] start_from_page: int = 0 inject_on_first_request: bool = False @@ -40,22 +42,21 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._page_size = page_size @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: if self.inject_on_first_request: return self._page return None def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Any]: + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Any | None: # Stop paginating when there are fewer records than the page size or the current page has no records if (self._page_size and last_page_size < self._page_size) or last_page_size == 0: return None - else: - self._page += 1 - return self._page + self._page += 1 + return self._page - def reset(self, reset_value: Optional[Any] = None) -> None: + def reset(self, reset_value: Any | None = None) -> None: if reset_value is None: self._page = self.start_from_page elif not isinstance(reset_value, int): @@ -65,5 +66,5 @@ def reset(self, reset_value: Optional[Any] = None) -> None: else: self._page = reset_value - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: return self._page_size diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py index 0b350d33..03afbbe5 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py @@ -1,34 +1,31 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import requests + from airbyte_cdk.sources.types import Record @dataclass class PaginationStrategy: - """ - Defines how to get the next page token - """ + """Defines how to get the next page token""" @property @abstractmethod - def initial_token(self) -> Optional[Any]: - """ - Return the initial value of the token - """ + def initial_token(self) -> Any | None: + """Return the initial value of the token""" @abstractmethod def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Any]: - """ - :param response: response to process + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Any | None: + """:param response: response to process :param last_page_size: the number of records read from the response :param last_record: the last record extracted from the response :return: next page token. Returns None if there are no more pages to fetch @@ -36,13 +33,9 @@ def next_page_token( pass @abstractmethod - def reset(self, reset_value: Optional[Any] = None) -> None: - """ - Reset the pagination's inner state - """ + def reset(self, reset_value: Any | None = None) -> None: + """Reset the pagination's inner state""" @abstractmethod - def get_page_size(self) -> Optional[int]: - """ - :return: page size: The number of records to fetch in a page. Returns None if unspecified - """ + def get_page_size(self) -> int | None: + """:return: page size: The number of records to fetch in a page. Returns None if unspecified""" diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py index 3f322aa9..41b5ce55 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any import requests + from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( PaginationStrategy, @@ -16,12 +18,11 @@ class PaginationStopCondition(ABC): @abstractmethod def is_met(self, record: Record) -> bool: - """ - Given a condition is met, the pagination will stop + """Given a condition is met, the pagination will stop :param record: a record used to evaluate the condition """ - raise NotImplementedError() + raise NotImplementedError class CursorStopCondition(PaginationStopCondition): @@ -38,8 +39,8 @@ def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStop self._stop_condition = stop_condition def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Optional[Any]: + self, response: requests.Response, last_page_size: int, last_record: Record | None + ) -> Any | None: # We evaluate in reverse order because the assumption is that most of the APIs using data feed structure will return records in # descending order. 
In terms of performance/memory, we return the records lazily if last_record and self._stop_condition.is_met(last_record): @@ -49,9 +50,9 @@ def next_page_token( def reset(self) -> None: self._delegate.reset() - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: return self._delegate.get_page_size() @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: return self._delegate.initial_token diff --git a/airbyte_cdk/sources/declarative/requesters/request_option.py b/airbyte_cdk/sources/declarative/requesters/request_option.py index d13d2056..062f3a40 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_option.py +++ b/airbyte_cdk/sources/declarative/requesters/request_option.py @@ -1,18 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass from enum import Enum -from typing import Any, Mapping, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString class RequestOptionType(Enum): - """ - Describes where to set a value on a request - """ + """Describes where to set a value on a request""" request_parameter = "request_parameter" header = "header" @@ -22,15 +22,14 @@ class RequestOptionType(Enum): @dataclass class RequestOption: - """ - Describes an option to set on a request + """Describes an option to set on a request Attributes: field_name (str): Describes the name of the parameter to inject inject_into (RequestOptionType): Describes where in the HTTP request to inject the parameter """ - field_name: Union[InterpolatedString, str] + field_name: InterpolatedString | str inject_into: RequestOptionType parameters: InitVar[Mapping[str, Any]] diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py index 5ce7c9a3..277ece70 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py @@ -1,9 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
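StopConditionPaginationStrategyDecorator, patched above, ends pagination as soon as the last record of a page meets a cursor-based condition (the comment in the hunk notes this assumes records arrive in descending order), and otherwise delegates to the wrapped strategy. A compact sketch:

from __future__ import annotations

from typing import Any


class StopConditionSketch:
    """Illustrative decorator: stop paginating once the last record seen satisfies the condition."""

    def __init__(self, delegate, is_met) -> None:
        self._delegate = delegate  # the wrapped pagination strategy
        self._is_met = is_met      # callable(record) -> bool, e.g. "record is older than the cursor"

    def next_page_token(self, response, last_page_size: int, last_record) -> Any | None:
        if last_record and self._is_met(last_record):
            return None
        return self._delegate.next_page_token(response, last_page_size, last_record)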
# +from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.request_option import ( @@ -18,17 +20,16 @@ @dataclass class DatetimeBasedRequestOptionsProvider(RequestOptionsProvider): - """ - Request options provider that extracts fields from the stream_slice and injects them into the respective location in the + """Request options provider that extracts fields from the stream_slice and injects them into the respective location in the outbound request being made """ config: Config parameters: InitVar[Mapping[str, Any]] - start_time_option: Optional[RequestOption] = None - end_time_option: Optional[RequestOption] = None - partition_field_start: Optional[str] = None - partition_field_end: Optional[str] = None + start_time_option: RequestOption | None = None + end_time_option: RequestOption | None = None + partition_field_start: str | None = None + partition_field_end: str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._partition_field_start = InterpolatedString.create( @@ -41,41 +42,41 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.header, stream_slice) def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._get_request_options(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json, stream_slice) def _get_request_options( - self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: options: MutableMapping[str, Any] = {} if not stream_slice: diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py 
b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py index 449da977..5f73b01c 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py @@ -1,9 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( RequestOptionsProvider, @@ -13,8 +15,7 @@ @dataclass class DefaultRequestOptionsProvider(RequestOptionsProvider): - """ - Request options provider that extracts fields from the stream_slice and injects them into the respective location in the + """Request options provider that extracts fields from the stream_slice and injects them into the respective location in the outbound request being made """ @@ -26,35 +27,35 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return {} def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return {} diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py index 6403417c..d0c22d63 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
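Both request-options providers above follow the mechanical rewrite applied throughout this patch: Mapping and friends are imported from collections.abc, and Optional[X] / Union[X, Y] collapse to X | None / X | Y under from __future__ import annotations. A self-contained sketch of the pattern, with a made-up provider name:

    from __future__ import annotations

    from collections.abc import Mapping
    from typing import Any


    class ExampleOptionsProvider:
        # With postponed evaluation of annotations (PEP 563), "X | None" is kept
        # as a string rather than evaluated at definition time, so the PEP 604
        # form can be used in annotations even before Python 3.10.
        def get_request_params(
            self,
            *,
            stream_state: Mapping[str, Any] | None = None,
            next_page_token: Mapping[str, Any] | None = None,
        ) -> Mapping[str, Any]:
            return dict(next_page_token or {})


    print(ExampleOptionsProvider().get_request_params(next_page_token={"page": 2}))  # {'page': 2}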
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import ( InterpolatedNestedMapping, @@ -15,19 +17,15 @@ @dataclass class InterpolatedNestedRequestInputProvider: - """ - Helper class that generically performs string interpolation on a provided deeply nested dictionary or string input - """ + """Helper class that generically performs string interpolation on a provided deeply nested dictionary or string input""" parameters: InitVar[Mapping[str, Any]] - request_inputs: Optional[Union[str, NestedMapping]] = field(default=None) + request_inputs: str | NestedMapping | None = field(default=None) config: Config = field(default_factory=dict) - _interpolator: Optional[Union[InterpolatedString, InterpolatedNestedMapping]] = field( - init=False, repr=False, default=None - ) - _request_inputs: Optional[Union[str, NestedMapping]] = field( + _interpolator: InterpolatedString | InterpolatedNestedMapping | None = field( init=False, repr=False, default=None ) + _request_inputs: str | NestedMapping | None = field(init=False, repr=False, default=None) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._request_inputs = self.request_inputs or {} @@ -42,12 +40,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def eval_request_inputs( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Returns the request inputs to set on an outgoing HTTP request + """Returns the request inputs to set on an outgoing HTTP request :param stream_state: The stream state :param stream_slice: The stream slice diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py index 0278df35..7a2e50c5 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Tuple, Type, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -12,19 +14,15 @@ @dataclass class InterpolatedRequestInputProvider: - """ - Helper class that generically performs string interpolation on the provided dictionary or string input - """ + """Helper class that generically performs string interpolation on the provided dictionary or string input""" parameters: InitVar[Mapping[str, Any]] - request_inputs: Optional[Union[str, Mapping[str, str]]] = field(default=None) + request_inputs: str | Mapping[str, str] | None = field(default=None) config: Config = field(default_factory=dict) - _interpolator: Optional[Union[InterpolatedString, InterpolatedMapping]] = field( - init=False, repr=False, default=None - ) - _request_inputs: Optional[Union[str, Mapping[str, str]]] = field( + _interpolator: InterpolatedString | InterpolatedMapping | None = field( init=False, repr=False, default=None ) + _request_inputs: str | Mapping[str, str] | None = field(init=False, repr=False, default=None) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._request_inputs = self.request_inputs or {} @@ -37,14 +35,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def eval_request_inputs( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - valid_key_types: Optional[Tuple[Type[Any]]] = None, - valid_value_types: Optional[Tuple[Type[Any], ...]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + valid_key_types: tuple[type[Any]] | None = None, + valid_value_types: tuple[type[Any], ...] | None = None, ) -> Mapping[str, Any]: - """ - Returns the request inputs to set on an outgoing HTTP request + """Returns the request inputs to set on an outgoing HTTP request :param stream_state: The stream state :param stream_slice: The stream slice diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py index bd8cfc17..00242446 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py @@ -1,9 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any, Union + +from deprecated import deprecated from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import NestedMapping from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_nested_request_input_provider import ( @@ -17,7 +21,7 @@ ) from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.types import Config, StreamSlice, StreamState -from deprecated import deprecated + RequestInput = Union[str, Mapping[str, str]] ValidRequestTypes = (str, list) @@ -25,8 +29,7 @@ @dataclass class InterpolatedRequestOptionsProvider(RequestOptionsProvider): - """ - Defines the request options to set on an outgoing HTTP request by evaluating `InterpolatedMapping`s + """Defines the request options to set on an outgoing HTTP request by evaluating `InterpolatedMapping`s Attributes: config (Config): The user-provided configuration as specified by the source's spec @@ -38,10 +41,10 @@ class InterpolatedRequestOptionsProvider(RequestOptionsProvider): parameters: InitVar[Mapping[str, Any]] config: Config = field(default_factory=dict) - request_parameters: Optional[RequestInput] = None - request_headers: Optional[RequestInput] = None - request_body_data: Optional[RequestInput] = None - request_body_json: Optional[NestedMapping] = None + request_parameters: RequestInput | None = None + request_headers: RequestInput | None = None + request_body_data: RequestInput | None = None + request_body_json: NestedMapping | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.request_parameters is None: @@ -74,9 +77,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: interpolated_value = self._parameter_interpolator.eval_request_inputs( stream_state, @@ -92,9 +95,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._headers_interpolator.eval_request_inputs( stream_state, stream_slice, next_page_token @@ -103,10 +106,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._body_data_interpolator.eval_request_inputs( stream_state, stream_slice, @@ -118,9 +121,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: 
Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._body_json_interpolator.eval_request_inputs( stream_state, stream_slice, next_page_token @@ -131,12 +134,10 @@ def get_request_body_json( category=ExperimentalClassWarning, ) def request_options_contain_stream_state(self) -> bool: - """ - Temporary helper method used as we move low-code streams to the concurrent framework. This method determines if + """Temporary helper method used as we move low-code streams to the concurrent framework. This method determines if the InterpolatedRequestOptionsProvider has is a dependency on a non-thread safe interpolation context such as stream_state. """ - return ( self._check_if_interpolation_uses_stream_state(self.request_parameters) or self._check_if_interpolation_uses_stream_state(self.request_headers) @@ -146,18 +147,17 @@ def request_options_contain_stream_state(self) -> bool: @staticmethod def _check_if_interpolation_uses_stream_state( - request_input: Optional[Union[RequestInput, NestedMapping]], + request_input: RequestInput | NestedMapping | None, ) -> bool: if not request_input: return False - elif isinstance(request_input, str): + if isinstance(request_input, str): return "stream_state" in request_input - else: - for key, val in request_input.items(): - # Covers the case of RequestInput in the form of a string or Mapping[str, str]. It also covers the case - # of a NestedMapping where the value is a string. - # Note: Doesn't account for nested mappings for request_body_json, but I don't see stream_state used in that way - # in our code - if "stream_state" in key or (isinstance(val, str) and "stream_state" in val): - return True + for key, val in request_input.items(): + # Covers the case of RequestInput in the form of a string or Mapping[str, str]. It also covers the case + # of a NestedMapping where the value is a string. + # Note: Doesn't account for nested mappings for request_body_json, but I don't see stream_state used in that way + # in our code + if "stream_state" in key or (isinstance(val, str) and "stream_state" in val): + return True return False diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py index f0a94ecb..5cb38867 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py @@ -1,18 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.types import StreamSlice, StreamState @dataclass class RequestOptionsProvider: - """ - Defines the request options to set on an outgoing HTTP request + """Defines the request options to set on an outgoing HTTP request Options can be passed by - request parameter @@ -25,12 +26,11 @@ class RequestOptionsProvider: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. + """Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. E.g: you might want to define query parameters for paging if next_page_token is not None. """ @@ -40,9 +40,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.""" @@ -50,12 +50,11 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: - """ - Specifies how to populate the body of the request with a non-JSON payload. + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: + """Specifies how to populate the body of the request with a non-JSON payload. If returns a ready text that it will be sent as is. If returns a dict that it will be converted to a urlencoded form. @@ -68,12 +67,11 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies how to populate the body of the request with a JSON payload. + """Specifies how to populate the body of the request with a JSON payload. At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. """ diff --git a/airbyte_cdk/sources/declarative/requesters/request_path.py b/airbyte_cdk/sources/declarative/requesters/request_path.py index 378ea622..b6792506 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_path.py +++ b/airbyte_cdk/sources/declarative/requesters/request_path.py @@ -1,15 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping +from typing import Any @dataclass class RequestPath: - """ - Describes that a component value should be inserted into the path - """ + """Describes that a component value should be inserted into the path""" parameters: InitVar[Mapping[str, Any]] diff --git a/airbyte_cdk/sources/declarative/requesters/requester.py b/airbyte_cdk/sources/declarative/requesters/requester.py index 19003a83..3a50a6be 100644 --- a/airbyte_cdk/sources/declarative/requesters/requester.py +++ b/airbyte_cdk/sources/declarative/requesters/requester.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod +from collections.abc import Callable, Mapping, MutableMapping from enum import Enum -from typing import Any, Callable, Mapping, MutableMapping, Optional, Union +from typing import Any import requests + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( RequestOptionsProvider, @@ -15,9 +18,7 @@ class HttpMethod(Enum): - """ - Http Method to use when submitting an outgoing HTTP request - """ + """Http Method to use when submitting an outgoing HTTP request""" DELETE = "DELETE" GET = "GET" @@ -28,45 +29,36 @@ class HttpMethod(Enum): class Requester(RequestOptionsProvider): @abstractmethod def get_authenticator(self) -> DeclarativeAuthenticator: - """ - Specifies the authenticator to use when submitting requests - """ + """Specifies the authenticator to use when submitting requests""" pass @abstractmethod def get_url_base(self) -> str: - """ - :return: URL base for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "https://myapi.com/v1/" - """ + """:return: URL base for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "https://myapi.com/v1/" """ @abstractmethod def get_path( self, *, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, ) -> str: - """ - Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" - """ + """Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" """ @abstractmethod def get_method(self) -> HttpMethod: - """ - Specifies the HTTP method to use - """ + """Specifies the HTTP method to use""" @abstractmethod def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: - """ - Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. + """Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. E.g: you might want to define query parameters for paging if next_page_token is not None. 
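The Requester interface above exposes a URL base, a path, an HTTP method and per-request options. A rough sketch of how those pieces can be combined into one call with the requests library; the send helper and the example URL are illustrative, not part of the interface.

    from __future__ import annotations

    from enum import Enum

    import requests


    class HttpMethod(Enum):
        GET = "GET"
        POST = "POST"


    def send(
        method: HttpMethod,
        url_base: str,
        path: str,
        params: dict | None = None,
        headers: dict | None = None,
    ) -> requests.Response:
        # Per the get_url_base docstring above, url_base ends with "/", so the
        # path can simply be appended to form the full endpoint URL.
        return requests.request(method.value, url_base + path, params=params, headers=headers)


    # Example (hypothetical API): send(HttpMethod.GET, "https://myapi.com/v1/", "some_entity", params={"page": 1})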
""" @@ -75,24 +67,21 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method. - """ + """Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.""" @abstractmethod def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: - """ - Specifies how to populate the body of the request with a non-JSON payload. + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: + """Specifies how to populate the body of the request with a non-JSON payload. If returns a ready text that it will be sent as is. If returns a dict that it will be converted to a urlencoded form. @@ -105,12 +94,11 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies how to populate the body of the request with a JSON payload. + """Specifies how to populate the body of the request with a JSON payload. At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. """ @@ -118,18 +106,17 @@ def get_request_body_json( @abstractmethod def send_request( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - path: Optional[str] = None, - request_headers: Optional[Mapping[str, Any]] = None, - request_params: Optional[Mapping[str, Any]] = None, - request_body_data: Optional[Union[Mapping[str, Any], str]] = None, - request_body_json: Optional[Mapping[str, Any]] = None, - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - ) -> Optional[requests.Response]: - """ - Sends a request and returns the response. Might return no response if the error handler chooses to ignore the response or throw an exception in case of an error. + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + path: str | None = None, + request_headers: Mapping[str, Any] | None = None, + request_params: Mapping[str, Any] | None = None, + request_body_data: Mapping[str, Any] | str | None = None, + request_body_json: Mapping[str, Any] | None = None, + log_formatter: Callable[[requests.Response], Any] | None = None, + ) -> requests.Response | None: + """Sends a request and returns the response. Might return no response if the error handler chooses to ignore the response or throw an exception in case of an error. If path is set, the path configured on the requester itself is ignored. 
If header, params and body are set, they are merged with the ones configured on the requester itself. diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py index f3902dfc..60d5566f 100644 --- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py @@ -1,8 +1,11 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations - +from collections.abc import Callable, Iterable, Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Callable, Iterable, Mapping, Optional +from typing import Any + +from deprecated.classic import deprecated from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( @@ -17,7 +20,6 @@ from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.types import Config, StreamSlice, StreamState from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from deprecated.classic import deprecated @deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning) @@ -33,21 +35,17 @@ class AsyncRetriever(Retriever): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._job_orchestrator_factory = self.job_orchestrator_factory - self.__job_orchestrator: Optional[AsyncJobOrchestrator] = None + self.__job_orchestrator: AsyncJobOrchestrator | None = None self._parameters = parameters @property def state(self) -> StreamState: - """ - As a first iteration for sendgrid, there is no state to be managed - """ + """As a first iteration for sendgrid, there is no state to be managed""" return {} @state.setter def state(self, value: StreamState) -> None: - """ - As a first iteration for sendgrid, there is no state to be managed - """ + """As a first iteration for sendgrid, there is no state to be managed""" pass @property @@ -62,20 +60,17 @@ def _job_orchestrator(self) -> AsyncJobOrchestrator: return self.__job_orchestrator def _get_stream_state(self) -> StreamState: - """ - Gets the current state of the stream. + """Gets the current state of the stream. Returns: StreamState: Mapping[str, Any] """ - return self.state def _validate_and_get_stream_slice_partition( - self, stream_slice: Optional[StreamSlice] = None + self, stream_slice: StreamSlice | None = None ) -> AsyncPartition: - """ - Validates the stream_slice argument and returns the partition from it. + """Validates the stream_slice argument and returns the partition from it. Args: stream_slice (Optional[StreamSlice]): The stream slice to validate and extract the partition from. 
@@ -94,7 +89,7 @@ def _validate_and_get_stream_slice_partition( ) return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: + def stream_slices(self) -> Iterable[StreamSlice | None]: slices = self.stream_slicer.stream_slices() self.__job_orchestrator = self._job_orchestrator_factory(slices) @@ -108,7 +103,7 @@ def stream_slices(self) -> Iterable[Optional[StreamSlice]]: def read_records( self, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[StreamData]: stream_state: StreamState = self._get_stream_state() partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice) diff --git a/airbyte_cdk/sources/declarative/retrievers/retriever.py b/airbyte_cdk/sources/declarative/retrievers/retriever.py index 155de578..7c3b8c65 100644 --- a/airbyte_cdk/sources/declarative/retrievers/retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/retriever.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice from airbyte_cdk.sources.streams.core import StreamData @@ -11,18 +13,15 @@ class Retriever: - """ - Responsible for fetching a stream's records from an HTTP API source. - """ + """Responsible for fetching a stream's records from an HTTP API source.""" @abstractmethod def read_records( self, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[StreamData]: - """ - Fetch a stream's records from an HTTP API source + """Fetch a stream's records from an HTTP API source :param records_schema: json schema to describe record :param stream_slice: The stream slice to read data for @@ -30,7 +29,7 @@ def read_records( """ @abstractmethod - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: + def stream_slices(self) -> Iterable[StreamSlice | None]: """Returns the stream slices""" @property diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index 530cf5f5..348860ad 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -1,25 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
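retriever.py above keeps the two-method contract: stream_slices() yields the slices and read_records() is invoked once per slice. A toy in-memory implementation of that shape, with made-up data, only to show the control flow a caller drives:

    from __future__ import annotations

    from collections.abc import Iterable, Mapping
    from typing import Any


    class InMemoryRetriever:
        """Toy retriever: one slice per year, records filtered by that slice."""

        def __init__(self, records: list[dict[str, Any]]) -> None:
            self._records = records

        def stream_slices(self) -> Iterable[Mapping[str, Any]]:
            yield from ({"year": y} for y in sorted({r["year"] for r in self._records}))

        def read_records(
            self,
            records_schema: Mapping[str, Any],
            stream_slice: Mapping[str, Any] | None = None,
        ) -> Iterable[Mapping[str, Any]]:
            for record in self._records:
                if stream_slice is None or record["year"] == stream_slice["year"]:
                    yield record


    retriever = InMemoryRetriever([{"id": 1, "year": 2023}, {"id": 2, "year": 2024}])
    for _slice in retriever.stream_slices():
        print(_slice, list(retriever.read_records(records_schema={}, stream_slice=_slice)))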
# +from __future__ import annotations import json +from collections.abc import Callable, Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field from functools import partial from itertools import islice from typing import ( Any, - Callable, - Iterable, - List, - Mapping, - MutableMapping, - Optional, - Set, - Tuple, - Union, ) import requests + from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector from airbyte_cdk.sources.declarative.incremental import ResumableFullRefreshCursor @@ -42,13 +36,13 @@ from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.utils.mapping_helpers import combine_mappings + FULL_REFRESH_SYNC_COMPLETE_KEY = "__ab_full_refresh_sync_complete" @dataclass class SimpleRetriever(Retriever): - """ - Retrieves records by synchronously sending requests to fetch records. + """Retrieves records by synchronously sending requests to fetch records. The retriever acts as an orchestrator between the requester, the record selector, the paginator, and the stream slicer. @@ -74,24 +68,24 @@ class SimpleRetriever(Retriever): config: Config parameters: InitVar[Mapping[str, Any]] name: str - _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="") - primary_key: Optional[Union[str, List[str], List[List[str]]]] + _name: InterpolatedString | str = field(init=False, repr=False, default="") + primary_key: str | list[str] | list[list[str]] | None _primary_key: str = field(init=False, repr=False, default="") - paginator: Optional[Paginator] = None + paginator: Paginator | None = None stream_slicer: StreamSlicer = field( default_factory=lambda: SinglePartitionRouter(parameters={}) ) request_option_provider: RequestOptionsProvider = field( default_factory=lambda: DefaultRequestOptionsProvider(parameters={}) ) - cursor: Optional[DeclarativeCursor] = None + cursor: DeclarativeCursor | None = None ignore_stream_slicer_parameters_on_paginated_requests: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._paginator = self.paginator or NoPagination(parameters=parameters) - self._last_response: Optional[requests.Response] = None + self._last_response: requests.Response | None = None self._last_page_size: int = 0 - self._last_record: Optional[Record] = None + self._last_record: Record | None = None self._parameters = parameters self._name = ( InterpolatedString(self._name, parameters=parameters) @@ -105,9 +99,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: @property # type: ignore def name(self) -> str: - """ - :return: Stream name - """ + """:return: Stream name""" return ( str(self._name.eval(self.config)) if isinstance(self._name, InterpolatedString) @@ -120,10 +112,9 @@ def name(self, value: str) -> None: self._name = value def _get_mapping( - self, method: Callable[..., Optional[Union[Mapping[str, Any], str]]], **kwargs: Any - ) -> Tuple[Union[Mapping[str, Any], str], Set[str]]: - """ - Get mapping from the provided method, and get the keys of the mapping. + self, method: Callable[..., Mapping[str, Any] | str | None], **kwargs: Any + ) -> tuple[Mapping[str, Any] | str, set[str]]: + """Get mapping from the provided method, and get the keys of the mapping. If the method returns a string, it will return the string and an empty set. If the method returns a dict, it will return the dict and its keys. 
""" @@ -133,14 +124,13 @@ def _get_mapping( def _get_request_options( self, - stream_state: Optional[StreamData], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - paginator_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - stream_slicer_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - ) -> Union[Mapping[str, Any], str]: - """ - Get the request_option from the paginator and the stream slicer. + stream_state: StreamData | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + paginator_method: Callable[..., Mapping[str, Any] | str | None], + stream_slicer_method: Callable[..., Mapping[str, Any] | str | None], + ) -> Mapping[str, Any] | str: + """Get the request_option from the paginator and the stream slicer. Raise a ValueError if there's a key collision Returned merged mapping otherwise """ @@ -164,12 +154,11 @@ def _get_request_options( def _request_headers( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies request headers. + """Specifies request headers. Authentication headers will overwrite any overlapping headers returned from this method. """ headers = self._get_request_options( @@ -185,12 +174,11 @@ def _request_headers( def _request_params( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. + """Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. E.g: you might want to define query parameters for paging if next_page_token is not None. """ @@ -207,12 +195,11 @@ def _request_params( def _request_body_data( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: - """ - Specifies how to populate the body of the request with a non-JSON payload. + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: + """Specifies how to populate the body of the request with a non-JSON payload. If returns a ready text that it will be sent as is. If returns a dict that it will be converted to a urlencoded form. @@ -230,12 +217,11 @@ def _request_body_data( def _request_body_json( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: - """ - Specifies how to populate the body of the request with a JSON payload. + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: + """Specifies how to populate the body of the request with a JSON payload. 
At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. """ @@ -252,9 +238,8 @@ def _request_body_json( def _paginator_path( self, - ) -> Optional[str]: - """ - If the paginator points to a path, follow it, else return nothing so the requester is used. + ) -> str | None: + """If the paginator points to a path, follow it, else return nothing so the requester is used. :param stream_state: :param stream_slice: :param next_page_token: @@ -264,11 +249,11 @@ def _paginator_path( def _parse_response( self, - response: Optional[requests.Response], + response: requests.Response | None, stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: if not response: self._last_response = None @@ -289,7 +274,7 @@ def _parse_response( yield record @property # type: ignore - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: """The stream's primary key""" return self._primary_key @@ -298,9 +283,8 @@ def primary_key(self, value: str) -> None: if not isinstance(value, property): self._primary_key = value - def _next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: - """ - Specifies a pagination strategy. + def _next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: + """Specifies a pagination strategy. The value returned from this method is passed to most other methods in this class. Use it to form a request e.g: set headers or query params. @@ -312,8 +296,8 @@ def _fetch_next_page( self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[requests.Response]: + next_page_token: Mapping[str, Any] | None = None, + ) -> requests.Response | None: return self.requester.send_request( path=self._paginator_path(), stream_state=stream_state, @@ -344,7 +328,7 @@ def _fetch_next_page( # This logic is similar to _read_pages in the HttpStream class. When making changes here, consider making changes there as well. 
def _read_pages( self, - records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]], + records_generator_fn: Callable[[requests.Response | None], Iterable[StreamData]], stream_state: Mapping[str, Any], stream_slice: StreamSlice, ) -> Iterable[StreamData]: @@ -366,7 +350,7 @@ def _read_pages( def _read_single_page( self, - records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]], + records_generator_fn: Callable[[requests.Response | None], Iterable[StreamData]], stream_state: Mapping[str, Any], stream_slice: StreamSlice, ) -> Iterable[StreamData]: @@ -391,10 +375,9 @@ def _read_single_page( def read_records( self, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[StreamData]: - """ - Fetch a stream's records from an HTTP API source + """Fetch a stream's records from an HTTP API source :param records_schema: json schema to describe record :param stream_slice: The stream slice to read data for @@ -451,41 +434,37 @@ def read_records( def _get_most_recent_record( self, - current_most_recent: Optional[Record], - current_record: Optional[Record], + current_most_recent: Record | None, + current_record: Record | None, stream_slice: StreamSlice, - ) -> Optional[Record]: + ) -> Record | None: if self.cursor and current_record: if not current_most_recent: return current_record - else: - return ( - current_most_recent - if self.cursor.is_greater_than_or_equal(current_most_recent, current_record) - else current_record - ) - else: - return None + return ( + current_most_recent + if self.cursor.is_greater_than_or_equal(current_most_recent, current_record) + else current_record + ) + return None @staticmethod - def _extract_record(stream_data: StreamData, stream_slice: StreamSlice) -> Optional[Record]: - """ - As we allow the output of _read_pages to be StreamData, it can be multiple things. Therefore, we need to filter out and normalize + def _extract_record(stream_data: StreamData, stream_slice: StreamSlice) -> Record | None: + """As we allow the output of _read_pages to be StreamData, it can be multiple things. Therefore, we need to filter out and normalize to data to streamline the rest of the process. """ if isinstance(stream_data, Record): # Record is not part of `StreamData` but is the most common implementation of `Mapping[str, Any]` which is part of `StreamData` return stream_data - elif isinstance(stream_data, (dict, Mapping)): + if isinstance(stream_data, (dict, Mapping)): return Record(dict(stream_data), stream_slice) - elif isinstance(stream_data, AirbyteMessage) and stream_data.record: + if isinstance(stream_data, AirbyteMessage) and stream_data.record: return Record(stream_data.record.data, stream_slice) return None # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore - """ - Specifies the slices for this stream. See the stream slicing section of the docs for more information. + def stream_slices(self) -> Iterable[StreamSlice | None]: # type: ignore + """Specifies the slices for this stream. See the stream slicing section of the docs for more information. 
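_read_pages above keeps issuing requests until the paginator no longer yields a next_page_token. The loop below sketches that shape against plain callables, so fetch_page and next_token are placeholders rather than CDK APIs:

    from __future__ import annotations

    from collections.abc import Callable, Iterable, Mapping
    from typing import Any


    def read_pages(
        fetch_page: Callable[[Mapping[str, Any] | None], Mapping[str, Any]],
        next_token: Callable[[Mapping[str, Any]], Mapping[str, Any] | None],
    ) -> Iterable[Mapping[str, Any]]:
        token: Mapping[str, Any] | None = None
        while True:
            page = fetch_page(token)    # issue one request for the current page
            yield from page["records"]  # emit this page's records lazily
            token = next_token(page)    # ask the "paginator" for the next token
            if not token:               # no token -> last page reached
                return


    pages = [{"records": [{"id": 1}], "next": {"page": 2}}, {"records": [{"id": 2}], "next": None}]


    def fetch(token):
        return pages[(token or {"page": 1})["page"] - 1]


    print(list(read_pages(fetch, lambda page: page["next"])))  # [{'id': 1}, {'id': 2}]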
:param sync_mode: :param cursor_field: @@ -506,10 +485,10 @@ def state(self, value: StreamState) -> None: def _parse_records( self, - response: Optional[requests.Response], + response: requests.Response | None, stream_state: Mapping[str, Any], records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice], + stream_slice: StreamSlice | None, ) -> Iterable[StreamData]: yield from self._parse_response( response, @@ -529,8 +508,7 @@ def _to_partition_key(to_serialize: Any) -> str: @dataclass class SimpleRetrieverTestReadDecorator(SimpleRetriever): - """ - In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of + """In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of slices that are queried throughout a read command. """ @@ -544,15 +522,15 @@ def __post_init__(self, options: Mapping[str, Any]) -> None: ) # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore + def stream_slices(self) -> Iterable[StreamSlice | None]: # type: ignore return islice(super().stream_slices(), self.maximum_number_of_slices) def _fetch_next_page( self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[requests.Response]: + next_page_token: Mapping[str, Any] | None = None, + ) -> requests.Response | None: return self.requester.send_request( path=self._paginator_path(), stream_state=stream_state, diff --git a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py index a9b625e7..de0f98e8 100644 --- a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping +from typing import Any from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import JsonFileSchemaLoader from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader @@ -13,8 +15,7 @@ @dataclass class DefaultSchemaLoader(SchemaLoader): - """ - Loads a schema from the default location or returns an empty schema for streams that have not defined their schema file yet. + """Loads a schema from the default location or returns an empty schema for streams that have not defined their schema file yet. Attributes: config (Config): The user-provided configuration as specified by the source's spec @@ -29,12 +30,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.default_loader = JsonFileSchemaLoader(parameters=parameters, config=self.config) def get_json_schema(self) -> Mapping[str, Any]: - """ - Attempts to retrieve a schema from the default filepath location or returns the empty schema if a schema cannot be found. + """Attempts to retrieve a schema from the default filepath location or returns the empty schema if a schema cannot be found. 
:return: The empty schema """ - try: return self.default_loader.get_json_schema() except OSError: diff --git a/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py b/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py index 72a46b7e..691ef32c 100644 --- a/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, Mapping +from typing import Any from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader @@ -12,7 +14,7 @@ class InlineSchemaLoader(SchemaLoader): """Describes a stream's schema""" - schema: Dict[str, Any] + schema: dict[str, Any] parameters: InitVar[Mapping[str, Any]] def get_json_schema(self) -> Mapping[str, Any]: diff --git a/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py b/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py index af51fe5d..a38cf9a9 100644 --- a/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import pkgutil import sys +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Tuple, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader @@ -31,8 +33,7 @@ def _default_file_path() -> str: @dataclass class JsonFileSchemaLoader(ResourceSchemaLoader, SchemaLoader): - """ - Loads the schema from a json file + """Loads the schema from a json file Attributes: file_path (Union[InterpolatedString, str]): The path to the json file describing the schema @@ -43,7 +44,7 @@ class JsonFileSchemaLoader(ResourceSchemaLoader, SchemaLoader): config: Config parameters: InitVar[Mapping[str, Any]] - file_path: Union[InterpolatedString, str] = field(default="") + file_path: InterpolatedString | str = field(default="") def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not self.file_path: @@ -51,14 +52,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.file_path = InterpolatedString.create(self.file_path, parameters=parameters) def get_json_schema(self) -> Mapping[str, Any]: - # todo: It is worth revisiting if we can replace file_path with just file_name if every schema is in the /schemas directory + # TODO: It is worth revisiting if we can replace file_path with just file_name if every schema is in the /schemas directory # this would require that we find a creative solution to store or retrieve source_name in here since the files are mounted there json_schema_path = self._get_json_filepath() resource, schema_path = self.extract_resource_and_schema_path(json_schema_path) raw_json_file = pkgutil.get_data(resource, schema_path) if not raw_json_file: - raise IOError(f"Cannot find file {json_schema_path}") + raise OSError(f"Cannot find file {json_schema_path}") try: raw_schema = json.loads(raw_json_file) except ValueError as err: @@ -70,9 +71,8 @@ def _get_json_filepath(self) -> Any: return self.file_path.eval(self.config) # type: ignore # file_path is always cast 
to an interpolated string @staticmethod - def extract_resource_and_schema_path(json_schema_path: str) -> Tuple[str, str]: - """ - When the connector is running on a docker container, package_data is accessible from the resource (source_), so we extract + def extract_resource_and_schema_path(json_schema_path: str) -> tuple[str, str]: + """When the connector is running on a docker container, package_data is accessible from the resource (source_), so we extract the resource from the first part of the schema path and the remaining path is used to find the schema file. This is a slight hack to identify the source name while we are in the airbyte_cdk module. :param json_schema_path: The path to the schema JSON file diff --git a/airbyte_cdk/sources/declarative/schema/schema_loader.py b/airbyte_cdk/sources/declarative/schema/schema_loader.py index a6beb70a..b912eaa4 100644 --- a/airbyte_cdk/sources/declarative/schema/schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/schema_loader.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Mapping +from typing import Any @dataclass diff --git a/airbyte_cdk/sources/declarative/spec/spec.py b/airbyte_cdk/sources/declarative/spec/spec.py index 05fa079b..dbf5398d 100644 --- a/airbyte_cdk/sources/declarative/spec/spec.py +++ b/airbyte_cdk/sources/declarative/spec/spec.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional +from typing import Any from airbyte_cdk.models import ( AdvancedAuth, @@ -15,8 +17,7 @@ @dataclass class Spec: - """ - Returns a connection specification made up of information about the connector and how it can be configured + """Returns a connection specification made up of information about the connector and how it can be configured Attributes: connection_specification (Mapping[str, Any]): information related to how a connector can be configured @@ -25,14 +26,11 @@ class Spec: connection_specification: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] - documentation_url: Optional[str] = None - advanced_auth: Optional[AuthFlow] = None + documentation_url: str | None = None + advanced_auth: AuthFlow | None = None def generate_spec(self) -> ConnectorSpecification: - """ - Returns the connector specification according the spec block defined in the low code connector manifest. - """ - + """Returns the connector specification according the spec block defined in the low code connector manifest.""" obj: dict[str, Mapping[str, Any] | str | AdvancedAuth] = { "connectionSpecification": self.connection_specification } diff --git a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py index af9c438f..f08c4e4a 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
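JsonFileSchemaLoader above resolves the schema through pkgutil.get_data and now raises OSError (rather than its alias IOError) when the resource is missing. The same lookup in isolation, assuming a hypothetical source_example package that ships schemas/users.json as package data:

    import json
    import pkgutil


    def load_packaged_schema(package: str, relative_path: str) -> dict:
        # pkgutil.get_data returns the file's bytes, or None if the resource
        # (or the package itself) cannot be found.
        raw = pkgutil.get_data(package, relative_path)
        if raw is None:
            raise OSError(f"Cannot find file {relative_path} in package {package}")
        return json.loads(raw)


    # schema = load_packaged_schema("source_example", "schemas/users.json")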
# +from __future__ import annotations from abc import abstractmethod +from collections.abc import Iterable from dataclasses import dataclass -from typing import Iterable from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( RequestOptionsProvider, @@ -14,8 +15,7 @@ @dataclass class StreamSlicer(RequestOptionsProvider): - """ - Slices the stream into a subset of records. + """Slices the stream into a subset of records. Slices enable state checkpointing and data retrieval parallelization. The stream slicer keeps track of the cursor state as a dict of cursor_field -> cursor_value @@ -25,8 +25,7 @@ class StreamSlicer(RequestOptionsProvider): @abstractmethod def stream_slices(self) -> Iterable[StreamSlice]: - """ - Defines stream slices + """Defines stream slices :return: List of stream slices """ diff --git a/airbyte_cdk/sources/declarative/transformations/add_fields.py b/airbyte_cdk/sources/declarative/transformations/add_fields.py index fa920993..82995b3b 100644 --- a/airbyte_cdk/sources/declarative/transformations/add_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/add_fields.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, List, Mapping, Optional, Type, Union +from typing import Any import dpath + from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.transformations import RecordTransformation from airbyte_cdk.sources.types import Config, FieldPointer, StreamSlice, StreamState @@ -16,8 +19,8 @@ class AddedFieldDefinition: """Defines the field to add on a record""" path: FieldPointer - value: Union[InterpolatedString, str] - value_type: Optional[Type[Any]] + value: InterpolatedString | str + value_type: type[Any] | None parameters: InitVar[Mapping[str, Any]] @@ -27,14 +30,13 @@ class ParsedAddFieldDefinition: path: FieldPointer value: InterpolatedString - value_type: Optional[Type[Any]] + value_type: type[Any] | None parameters: InitVar[Mapping[str, Any]] @dataclass class AddFields(RecordTransformation): - """ - Transformation which adds field to an output record. The path of the added field can be nested. Adding nested fields will create all + """Transformation which adds field to an output record. The path of the added field can be nested. Adding nested fields will create all necessary parent objects (like mkdir -p). Adding fields to an array will extend the array to that index (filling intermediate indices with null values). So if you add a field at index 5 to the array ["value"], it will become ["value", null, null, null, null, "new_value"]. 
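AddFields, per its docstring above, creates intermediate parent objects as needed when writing a nested path; the transform body further down in this hunk does it with dpath.new. A tiny demonstration of that call on a plain dict, assuming a dpath release that exposes new at the top level as imported in the hunk:

    import dpath


    record = {"id": 1}
    # Write a value at a nested path, creating the intermediate "metadata" object.
    dpath.new(record, ["metadata", "source"], "api")
    print(record)  # {'id': 1, 'metadata': {'source': 'api'}}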
@@ -83,9 +85,9 @@ class AddFields(RecordTransformation): fields (List[AddedFieldDefinition]): A list of transformations (path and corresponding value) that will be added to the record """ - fields: List[AddedFieldDefinition] + fields: list[AddedFieldDefinition] parameters: InitVar[Mapping[str, Any]] - _parsed_fields: List[ParsedAddFieldDefinition] = field( + _parsed_fields: list[ParsedAddFieldDefinition] = field( init=False, repr=False, default_factory=list ) @@ -99,15 +101,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not isinstance(add_field.value, InterpolatedString): if not isinstance(add_field.value, str): raise f"Expected a string value for the AddFields transformation: {add_field}" - else: - self._parsed_fields.append( - ParsedAddFieldDefinition( - add_field.path, - InterpolatedString.create(add_field.value, parameters=parameters), - value_type=add_field.value_type, - parameters=parameters, - ) + self._parsed_fields.append( + ParsedAddFieldDefinition( + add_field.path, + InterpolatedString.create(add_field.value, parameters=parameters), + value_type=add_field.value_type, + parameters=parameters, ) + ) else: self._parsed_fields.append( ParsedAddFieldDefinition( @@ -120,10 +121,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: if config is None: config = {} @@ -133,5 +134,5 @@ def transform( value = parsed_field.value.eval(config, valid_types=valid_types, **kwargs) dpath.new(record, parsed_field.path, value) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return bool(self.__dict__ == other.__dict__) diff --git a/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py index 53db3d49..29b3dec0 100644 --- a/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py +++ b/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py @@ -1,9 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.sources.declarative.transformations import RecordTransformation from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -13,10 +14,10 @@ class KeysToLowerTransformation(RecordTransformation): def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: for key in set(record.keys()): record[key.lower()] = record.pop(key) diff --git a/airbyte_cdk/sources/declarative/transformations/remove_fields.py b/airbyte_cdk/sources/declarative/transformations/remove_fields.py index 8ac20a0d..12c682f2 100644 --- a/airbyte_cdk/sources/declarative/transformations/remove_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/remove_fields.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, List, Mapping, Optional +from typing import Any import dpath import dpath.exceptions + from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.declarative.transformations import RecordTransformation from airbyte_cdk.sources.types import Config, FieldPointer, StreamSlice, StreamState @@ -14,8 +17,7 @@ @dataclass class RemoveFields(RecordTransformation): - """ - A transformation which removes fields from a record. The fields removed are designated using FieldPointers. + """A transformation which removes fields from a record. The fields removed are designated using FieldPointers. During transformation, if a field or any of its parents does not exist in the record, no error is thrown. If an input field pointer references an item in a list (e.g: ["k", 0] in the object {"k": ["a", "b", "c"]}) then @@ -39,7 +41,7 @@ class RemoveFields(RecordTransformation): field_pointers (List[FieldPointer]): pointers to the fields that should be removed """ - field_pointers: List[FieldPointer] + field_pointers: list[FieldPointer] parameters: InitVar[Mapping[str, Any]] condition: str = "" @@ -50,13 +52,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: - """ - :param record: The record to be transformed + """:param record: The record to be transformed :return: the input record with the requested fields removed """ for pointer in self.field_pointers: diff --git a/airbyte_cdk/sources/declarative/transformations/transformation.py b/airbyte_cdk/sources/declarative/transformations/transformation.py index f5b22642..8522d215 100644 --- a/airbyte_cdk/sources/declarative/transformations/transformation.py +++ b/airbyte_cdk/sources/declarative/transformations/transformation.py @@ -1,30 +1,28 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @dataclass class RecordTransformation: - """ - Implementations of this class define transformations that can be applied to records of a stream. - """ + """Implementations of this class define transformations that can be applied to records of a stream.""" @abstractmethod def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: - """ - Transform a record by adding, deleting, or mutating fields directly from the record reference passed in argument. + """Transform a record by adding, deleting, or mutating fields directly from the record reference passed in argument. 
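For reference, a small sketch of using the RemoveFields transformation shown above; the record layout and field pointers are hypothetical:

from airbyte_cdk.sources.declarative.transformations.remove_fields import RemoveFields

strip_internal_fields = RemoveFields(
    field_pointers=[["_ab_internal"], ["metadata", "raw"]],  # hypothetical field pointers
    parameters={},
)

record = {
    "id": 1,
    "_ab_internal": {"cursor": "2024-01-01"},
    "metadata": {"raw": "payload", "source": "demo"},
}
strip_internal_fields.transform(record)
# record is now {"id": 1, "metadata": {"source": "demo"}};
# pointers that do not exist in a record are skipped silently, per the docstring above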
:param record: The input record to be transformed :param config: The user-provided configuration as specified by the source's spec diff --git a/airbyte_cdk/sources/declarative/types.py b/airbyte_cdk/sources/declarative/types.py index a4d0aeb1..57a7f5bb 100644 --- a/airbyte_cdk/sources/declarative/types.py +++ b/airbyte_cdk/sources/declarative/types.py @@ -13,6 +13,7 @@ StreamState, ) + # Note: This package originally contained class definitions for low-code CDK types, but we promoted them into the Python CDK. # We've migrated connectors in the repository to reference the new location, but these assignments are used to retain backwards # compatibility for sources created by OSS customers or on forks. This can be removed when we start bumping major versions. diff --git a/airbyte_cdk/sources/declarative/yaml_declarative_source.py b/airbyte_cdk/sources/declarative/yaml_declarative_source.py index cecb91a6..f4c22858 100644 --- a/airbyte_cdk/sources/declarative/yaml_declarative_source.py +++ b/airbyte_cdk/sources/declarative/yaml_declarative_source.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pkgutil -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any import yaml + from airbyte_cdk.models import AirbyteStateMessage, ConfiguredAirbyteCatalog from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( ConcurrentDeclarativeSource, @@ -13,20 +16,18 @@ from airbyte_cdk.sources.types import ConnectionDefinition -class YamlDeclarativeSource(ConcurrentDeclarativeSource[List[AirbyteStateMessage]]): +class YamlDeclarativeSource(ConcurrentDeclarativeSource[list[AirbyteStateMessage]]): """Declarative source defined by a yaml file""" def __init__( self, path_to_yaml: str, debug: bool = False, - catalog: Optional[ConfiguredAirbyteCatalog] = None, - config: Optional[Mapping[str, Any]] = None, - state: Optional[List[AirbyteStateMessage]] = None, + catalog: ConfiguredAirbyteCatalog | None = None, + config: Mapping[str, Any] | None = None, + state: list[AirbyteStateMessage] | None = None, ) -> None: - """ - :param path_to_yaml: Path to the yaml file describing the source - """ + """:param path_to_yaml: Path to the yaml file describing the source""" self._path_to_yaml = path_to_yaml source_config = self._read_and_parse_yaml_file(path_to_yaml) @@ -44,8 +45,7 @@ def _read_and_parse_yaml_file(self, path_to_yaml_file: str) -> ConnectionDefinit if yaml_config: decoded_yaml = yaml_config.decode() return self._parse(decoded_yaml) - else: - return {} + return {} def _emit_manifest_debug_message(self, extra_args: dict[str, Any]) -> None: extra_args["path_to_yaml"] = self._path_to_yaml @@ -53,8 +53,7 @@ def _emit_manifest_debug_message(self, extra_args: dict[str, Any]) -> None: @staticmethod def _parse(connection_definition_str: str) -> ConnectionDefinition: - """ - Parses a yaml file into a manifest. Component references still exist in the manifest which will be + """Parses a yaml file into a manifest. Component references still exist in the manifest which will be resolved during the creating of the DeclarativeSource. 
:param connection_definition_str: yaml string to parse :return: The ConnectionDefinition parsed from connection_definition_str diff --git a/airbyte_cdk/sources/embedded/base_integration.py b/airbyte_cdk/sources/embedded/base_integration.py index c2e67408..f83a4966 100644 --- a/airbyte_cdk/sources/embedded/base_integration.py +++ b/airbyte_cdk/sources/embedded/base_integration.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Generic, Iterable, Optional, TypeVar +from collections.abc import Iterable +from typing import Generic, TypeVar from airbyte_cdk.connector import TConfig from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type @@ -16,6 +18,7 @@ from airbyte_cdk.sources.embedded.tools import get_defined_id from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit + TOutput = TypeVar("TOutput") @@ -26,17 +29,15 @@ def __init__(self, runner: SourceRunner[TConfig], config: TConfig): self.source = runner self.config = config - self.last_state: Optional[AirbyteStateMessage] = None + self.last_state: AirbyteStateMessage | None = None @abstractmethod - def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Optional[TOutput]: - """ - Turn an Airbyte record into the appropriate output type for the integration. - """ + def _handle_record(self, record: AirbyteRecordMessage, id: str | None) -> TOutput | None: + """Turn an Airbyte record into the appropriate output type for the integration.""" pass def _load_data( - self, stream_name: str, state: Optional[AirbyteStateMessage] = None + self, stream_name: str, state: AirbyteStateMessage | None = None ) -> Iterable[TOutput]: catalog = self.source.discover(self.config) stream = get_stream(catalog, stream_name) diff --git a/airbyte_cdk/sources/embedded/catalog.py b/airbyte_cdk/sources/embedded/catalog.py index 62c7a623..eb4521a2 100644 --- a/airbyte_cdk/sources/embedded/catalog.py +++ b/airbyte_cdk/sources/embedded/catalog.py @@ -1,8 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
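A minimal sketch of instantiating the YamlDeclarativeSource defined above; the manifest path is hypothetical and is resolved relative to the connector package via pkgutil, as in _read_and_parse_yaml_file:

from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource

# catalog, config, and state are optional and only needed to enable the
# concurrent execution path of ConcurrentDeclarativeSource.
source = YamlDeclarativeSource(path_to_yaml="manifest.yaml")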
# - -from typing import List, Optional +from __future__ import annotations from airbyte_cdk.models import ( AirbyteCatalog, @@ -15,11 +14,11 @@ from airbyte_cdk.sources.embedded.tools import get_first -def get_stream(catalog: AirbyteCatalog, stream_name: str) -> Optional[AirbyteStream]: +def get_stream(catalog: AirbyteCatalog, stream_name: str) -> AirbyteStream | None: return get_first(catalog.streams, lambda s: s.name == stream_name) -def get_stream_names(catalog: AirbyteCatalog) -> List[str]: +def get_stream_names(catalog: AirbyteCatalog) -> list[str]: return [stream.name for stream in catalog.streams] @@ -27,8 +26,8 @@ def to_configured_stream( stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh, destination_sync_mode: DestinationSyncMode = DestinationSyncMode.append, - cursor_field: Optional[List[str]] = None, - primary_key: Optional[List[List[str]]] = None, + cursor_field: list[str] | None = None, + primary_key: list[list[str]] | None = None, ) -> ConfiguredAirbyteStream: return ConfiguredAirbyteStream( stream=stream, @@ -40,7 +39,7 @@ def to_configured_stream( def to_configured_catalog( - configured_streams: List[ConfiguredAirbyteStream], + configured_streams: list[ConfiguredAirbyteStream], ) -> ConfiguredAirbyteCatalog: return ConfiguredAirbyteCatalog(streams=configured_streams) diff --git a/airbyte_cdk/sources/embedded/runner.py b/airbyte_cdk/sources/embedded/runner.py index 43217f15..3b0a8f5e 100644 --- a/airbyte_cdk/sources/embedded/runner.py +++ b/airbyte_cdk/sources/embedded/runner.py @@ -1,11 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Generic, Iterable, Optional +from collections.abc import Iterable +from typing import Generic from airbyte_cdk.connector import TConfig from airbyte_cdk.models import ( @@ -32,7 +33,7 @@ def read( self, config: TConfig, catalog: ConfiguredAirbyteCatalog, - state: Optional[AirbyteStateMessage], + state: AirbyteStateMessage | None, ) -> Iterable[AirbyteMessage]: pass @@ -52,6 +53,6 @@ def read( self, config: TConfig, catalog: ConfiguredAirbyteCatalog, - state: Optional[AirbyteStateMessage], + state: AirbyteStateMessage | None, ) -> Iterable[AirbyteMessage]: return self._source.read(self._logger, config, catalog, state=[state] if state else []) diff --git a/airbyte_cdk/sources/embedded/tools.py b/airbyte_cdk/sources/embedded/tools.py index 6bffa1a0..207f19bb 100644 --- a/airbyte_cdk/sources/embedded/tools.py +++ b/airbyte_cdk/sources/embedded/tools.py @@ -1,20 +1,23 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
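For illustration, a sketch of building a configured catalog with the embedded catalog helpers above; `catalog` would come from SourceRunner.discover(config), and the stream name is hypothetical:

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.embedded.catalog import (
    get_stream,
    to_configured_catalog,
    to_configured_stream,
)

stream = get_stream(catalog, "users")  # returns None if the stream is not in the catalog
if stream is not None:
    configured_catalog = to_configured_catalog(
        [to_configured_stream(stream, sync_mode=SyncMode.incremental)]
    )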
# +from __future__ import annotations -from typing import Any, Callable, Dict, Iterable, Optional +from collections.abc import Callable, Iterable +from typing import Any import dpath + from airbyte_cdk.models import AirbyteStream def get_first( iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True -) -> Optional[Any]: +) -> Any | None: return next(filter(predicate, iterable), None) -def get_defined_id(stream: AirbyteStream, data: Dict[str, Any]) -> Optional[str]: +def get_defined_id(stream: AirbyteStream, data: dict[str, Any]) -> str | None: if not stream.source_defined_primary_key: return None primary_key = [] diff --git a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py index c0234ca1..7f64b8aa 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import abstractmethod -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING from airbyte_cdk.sources import Source from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy @@ -16,6 +17,7 @@ ) from airbyte_cdk.sources.streams.core import Stream + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream @@ -23,10 +25,9 @@ class AbstractFileBasedAvailabilityStrategy(AvailabilityStrategy): @abstractmethod def check_availability( - self, stream: Stream, logger: logging.Logger, _: Optional[Source] - ) -> Tuple[bool, Optional[str]]: - """ - Perform a connection check for the stream. + self, stream: Stream, logger: logging.Logger, _: Source | None + ) -> tuple[bool, str | None]: + """Perform a connection check for the stream. Returns (True, None) if successful, otherwise (False, ). """ @@ -34,10 +35,9 @@ def check_availability( @abstractmethod def check_availability_and_parsability( - self, stream: "AbstractFileBasedStream", logger: logging.Logger, _: Optional[Source] - ) -> Tuple[bool, Optional[str]]: - """ - Performs a connection check for the stream, as well as additional checks that + self, stream: AbstractFileBasedStream, logger: logging.Logger, _: Source | None + ) -> tuple[bool, str | None]: + """Performs a connection check for the stream, as well as additional checks that verify that the connection is working as expected. Returns (True, None) if successful, otherwise (False, ). 
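The get_first helper defined in tools.py above simply wraps next(filter(...), None); a quick sketch of its behavior:

from airbyte_cdk.sources.embedded.tools import get_first

get_first([1, 3, 4, 6], lambda n: n % 2 == 0)  # -> 4, the first element matching the predicate
get_first([1, 3, 5], lambda n: n % 2 == 0)     # -> None when nothing matches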
@@ -46,7 +46,7 @@ def check_availability_and_parsability( class AbstractFileBasedAvailabilityStrategyWrapper(AbstractAvailabilityStrategy): - def __init__(self, stream: "AbstractFileBasedStream"): + def __init__(self, stream: AbstractFileBasedStream): self.stream = stream def check_availability(self, logger: logging.Logger) -> StreamAvailability: @@ -57,9 +57,7 @@ def check_availability(self, logger: logging.Logger) -> StreamAvailability: return StreamAvailable() return StreamUnavailable(reason or "") - def check_availability_and_parsability( - self, logger: logging.Logger - ) -> Tuple[bool, Optional[str]]: + def check_availability_and_parsability(self, logger: logging.Logger) -> tuple[bool, str | None]: return self.stream.availability_strategy.check_availability_and_parsability( self.stream, logger, None ) diff --git a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py index cf985d9e..5976db55 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import traceback -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING from airbyte_cdk import AirbyteTracedException from airbyte_cdk.sources import Source @@ -20,6 +21,7 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream @@ -29,10 +31,9 @@ def __init__(self, stream_reader: AbstractFileBasedStreamReader): self.stream_reader = stream_reader def check_availability( - self, stream: "AbstractFileBasedStream", logger: logging.Logger, _: Optional[Source] - ) -> Tuple[bool, Optional[str]]: # type: ignore[override] - """ - Perform a connection check for the stream (verify that we can list files from the stream). + self, stream: AbstractFileBasedStream, logger: logging.Logger, _: Source | None + ) -> tuple[bool, str | None]: # type: ignore[override] + """Perform a connection check for the stream (verify that we can list files from the stream). Returns (True, None) if successful, otherwise (False, ). """ @@ -44,10 +45,9 @@ def check_availability( return True, None def check_availability_and_parsability( - self, stream: "AbstractFileBasedStream", logger: logging.Logger, _: Optional[Source] - ) -> Tuple[bool, Optional[str]]: - """ - Perform a connection check for the stream. + self, stream: AbstractFileBasedStream, logger: logging.Logger, _: Source | None + ) -> tuple[bool, str | None]: + """Perform a connection check for the stream. Returns (True, None) if successful, otherwise (False, ). @@ -69,7 +69,7 @@ def check_availability_and_parsability( return False, config_check_error_message try: file = self._check_list_files(stream) - if not parser.parser_max_n_files_for_parsability == 0: + if parser.parser_max_n_files_for_parsability != 0: self._check_parse_record(stream, file, logger) else: # If the parser is set to not check parsability, we still want to check that we can open the file. 
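A minimal sketch of a custom strategy implementing the abstract interface above; the "always available" behaviour is purely illustrative:

import logging

from airbyte_cdk.sources.file_based.availability_strategy.abstract_file_based_availability_strategy import (
    AbstractFileBasedAvailabilityStrategy,
)


class AlwaysAvailableStrategy(AbstractFileBasedAvailabilityStrategy):
    """Illustrative only: report every stream as available without performing any checks."""

    def check_availability(self, stream, logger: logging.Logger, _):
        return True, None

    def check_availability_and_parsability(self, stream, logger: logging.Logger, _):
        return True, None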
@@ -82,9 +82,8 @@ def check_availability_and_parsability( return True, None - def _check_list_files(self, stream: "AbstractFileBasedStream") -> RemoteFile: - """ - Check that we can list files from the stream. + def _check_list_files(self, stream: AbstractFileBasedStream) -> RemoteFile: + """Check that we can list files from the stream. Returns the first file if successful, otherwise raises a CheckAvailabilityError. """ @@ -102,7 +101,7 @@ def _check_list_files(self, stream: "AbstractFileBasedStream") -> RemoteFile: return file def _check_parse_record( - self, stream: "AbstractFileBasedStream", file: RemoteFile, logger: logging.Logger + self, stream: AbstractFileBasedStream, file: RemoteFile, logger: logging.Logger ) -> None: parser = stream.get_parser() @@ -135,4 +134,4 @@ def _check_parse_record( file=file.uri, ) - return None + return diff --git a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py index ee220388..de0f10ad 100644 --- a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py +++ b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py @@ -1,16 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import copy from abc import abstractmethod -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal import dpath +from pydantic.v1 import AnyUrl, BaseModel, Field + from airbyte_cdk import OneOfOptionConfig from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.utils import schema_helpers -from pydantic.v1 import AnyUrl, BaseModel, Field class DeliverRecords(BaseModel): @@ -32,12 +34,11 @@ class Config(OneOfOptionConfig): class AbstractFileBasedSpec(BaseModel): - """ - Used during spec; allows the developer to configure the cloud provider specific options + """Used during spec; allows the developer to configure the cloud provider specific options that are needed when users configure a file-based source. """ - start_date: Optional[str] = Field( + start_date: str | None = Field( title="Start Date", description="UTC date and time in the format 2017-01-25T00:00:00.000000Z. Any file modified before this date will not be replicated.", examples=["2021-01-01T00:00:00.000000Z"], @@ -47,13 +48,13 @@ class AbstractFileBasedSpec(BaseModel): order=1, ) - streams: List[FileBasedStreamConfig] = Field( + streams: list[FileBasedStreamConfig] = Field( title="The list of streams to sync", description='Each instance of this configuration defines a stream. Use this to define which files belong in the stream, their format, and how they should be parsed and validated. When sending data to warehouse destination such as Snowflake or BigQuery, each stream is a separate table.', order=10, ) - delivery_method: Union[DeliverRecords, DeliverRawFiles] = Field( + delivery_method: DeliverRecords | DeliverRawFiles = Field( title="Delivery Method", discriminator="delivery_type", type="object", @@ -67,17 +68,13 @@ class AbstractFileBasedSpec(BaseModel): @classmethod @abstractmethod def documentation_url(cls) -> AnyUrl: - """ - :return: link to docs page for this source e.g. "https://docs.airbyte.com/integrations/sources/s3" - """ + """:return: link to docs page for this source e.g. 
"https://docs.airbyte.com/integrations/sources/s3" """ @classmethod - def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: - """ - Generates the mapping comprised of the config fields - """ + def schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: + """Generates the mapping comprised of the config fields""" schema = super().schema(*args, **kwargs) - transformed_schema: Dict[str, Any] = copy.deepcopy(schema) + transformed_schema: dict[str, Any] = copy.deepcopy(schema) schema_helpers.expand_refs(transformed_schema) cls.replace_enum_allOf_and_anyOf(transformed_schema) cls.remove_discriminator(transformed_schema) @@ -85,14 +82,13 @@ def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: return transformed_schema @staticmethod - def remove_discriminator(schema: Dict[str, Any]) -> None: - """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" + def remove_discriminator(schema: dict[str, Any]) -> None: + """Pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" dpath.delete(schema, "properties/**/discriminator") @staticmethod - def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: - """ - allOfs are not supported by the UI, but pydantic is automatically writing them for enums. + def replace_enum_allOf_and_anyOf(schema: dict[str, Any]) -> dict[str, Any]: + """AllOfs are not supported by the UI, but pydantic is automatically writing them for enums. Unpacks the enums under allOf and moves them up a level under the enum key anyOfs are also not supported by the UI, so we replace them with the similar oneOf, with the additional validation that an incoming config only matches exactly one of a field's types. @@ -134,7 +130,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: return schema @staticmethod - def move_enum_to_root(object_property: Dict[str, Any]) -> None: + def move_enum_to_root(object_property: dict[str, Any]) -> None: if "allOf" in object_property and "enum" in object_property["allOf"][0]: object_property["enum"] = object_property["allOf"][0]["enum"] object_property.pop("allOf") diff --git a/airbyte_cdk/sources/file_based/config/avro_format.py b/airbyte_cdk/sources/file_based/config/avro_format.py index ac8fafef..e3253f4f 100644 --- a/airbyte_cdk/sources/file_based/config/avro_format.py +++ b/airbyte_cdk/sources/file_based/config/avro_format.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from pydantic.v1 import BaseModel, Field from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig -from pydantic.v1 import BaseModel, Field class AvroFormat(BaseModel): diff --git a/airbyte_cdk/sources/file_based/config/csv_format.py b/airbyte_cdk/sources/file_based/config/csv_format.py index 83789c45..2f99f499 100644 --- a/airbyte_cdk/sources/file_based/config/csv_format.py +++ b/airbyte_cdk/sources/file_based/config/csv_format.py @@ -1,15 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import codecs from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any -from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig from pydantic.v1 import BaseModel, Field, root_validator, validator from pydantic.v1.error_wrappers import ValidationError +from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig + class InferenceType(Enum): NONE = "None" @@ -59,7 +61,7 @@ class Config(OneOfOptionConfig): CsvHeaderDefinitionType.USER_PROVIDED.value, const=True, ) - column_names: List[str] = Field( + column_names: list[str] = Field( title="Column Names", description="The column names that will be used while emitting the CSV records", ) @@ -68,7 +70,7 @@ def has_header_row(self) -> bool: return False @validator("column_names") - def validate_column_names(cls, v: List[str]) -> List[str]: + def validate_column_names(cls, v: list[str]) -> list[str]: if not v: raise ValueError( "At least one column name needs to be provided when using user provided headers" @@ -99,12 +101,12 @@ class Config(OneOfOptionConfig): default='"', description="The character used for quoting CSV values. To disallow quoting, make this field blank.", ) - escape_char: Optional[str] = Field( + escape_char: str | None = Field( title="Escape Character", default=None, description="The character used for escaping special characters. To disallow escaping, leave this field blank.", ) - encoding: Optional[str] = Field( + encoding: str | None = Field( default="utf8", description='The character encoding of the CSV data. Leave blank to default to UTF8. See list of python encodings for allowable options.', ) @@ -113,7 +115,7 @@ class Config(OneOfOptionConfig): default=True, description="Whether two quotes in a quoted CSV value denote a single quote in the data.", ) - null_values: Set[str] = Field( + null_values: set[str] = Field( title="Null Values", default=[], description="A set of case-sensitive strings that should be interpreted as null values. For example, if the value 'NA' should be interpreted as null, enter 'NA' in this field.", @@ -133,19 +135,17 @@ class Config(OneOfOptionConfig): default=0, description="The number of rows to skip after the header row.", ) - header_definition: Union[CsvHeaderFromCsv, CsvHeaderAutogenerated, CsvHeaderUserProvided] = ( - Field( - title="CSV Header Definition", - default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value), - description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", - ) + header_definition: CsvHeaderFromCsv | CsvHeaderAutogenerated | CsvHeaderUserProvided = Field( + title="CSV Header Definition", + default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value), + description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. 
If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", ) - true_values: Set[str] = Field( + true_values: set[str] = Field( title="True Values", default=DEFAULT_TRUE_VALUES, description="A set of case-sensitive strings that should be interpreted as true values.", ) - false_values: Set[str] = Field( + false_values: set[str] = Field( title="False Values", default=DEFAULT_FALSE_VALUES, description="A set of case-sensitive strings that should be interpreted as false values.", @@ -193,7 +193,7 @@ def validate_encoding(cls, v: str) -> str: return v @root_validator - def validate_optional_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def validate_optional_args(cls, values: dict[str, Any]) -> dict[str, Any]: definition_type = values.get("header_definition_type") column_names = values.get("user_provided_column_names") if definition_type == CsvHeaderDefinitionType.USER_PROVIDED and not column_names: diff --git a/airbyte_cdk/sources/file_based/config/excel_format.py b/airbyte_cdk/sources/file_based/config/excel_format.py index 02a4f52d..46ef7916 100644 --- a/airbyte_cdk/sources/file_based/config/excel_format.py +++ b/airbyte_cdk/sources/file_based/config/excel_format.py @@ -1,10 +1,12 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig from pydantic.v1 import BaseModel, Field +from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig + class ExcelFormat(BaseModel): class Config(OneOfOptionConfig): diff --git a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py index 5d92f6f0..79fca918 100644 --- a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py +++ b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py @@ -1,9 +1,13 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from enum import Enum -from typing import Any, List, Mapping, Optional, Union +from typing import Any, Optional + +from pydantic.v1 import BaseModel, Field, validator from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat @@ -13,9 +17,9 @@ from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError from airbyte_cdk.sources.file_based.schema_helpers import type_mapping_to_jsonschema -from pydantic.v1 import BaseModel, Field, validator -PrimaryKeyType = Optional[Union[str, List[str]]] + +PrimaryKeyType = Optional[str | list[str]] class ValidationPolicy(Enum): @@ -26,13 +30,13 @@ class ValidationPolicy(Enum): class FileBasedStreamConfig(BaseModel): name: str = Field(title="Name", description="The name of the stream.") - globs: Optional[List[str]] = Field( + globs: list[str] | None = Field( default=["**"], title="Globs", description='The pattern used to specify which files should be selected from the file system. For more information on glob pattern matching look here.', order=1, ) - legacy_prefix: Optional[str] = Field( + legacy_prefix: str | None = Field( title="Legacy Prefix", description="The path prefix configured in v3 versions of the S3 connector. 
This option is deprecated in favor of a single glob.", airbyte_hidden=True, @@ -42,11 +46,11 @@ class FileBasedStreamConfig(BaseModel): description="The name of the validation policy that dictates sync behavior when a record does not adhere to the stream schema.", default=ValidationPolicy.emit_record, ) - input_schema: Optional[str] = Field( + input_schema: str | None = Field( title="Input Schema", description="The schema that will be used to validate records extracted from the file. This will override the stream schema that is auto-detected from incoming files.", ) - primary_key: Optional[str] = Field( + primary_key: str | None = Field( title="Primary Key", description="The column or columns (for a composite key) that serves as the unique identifier of a record. If empty, the primary key will default to the parser's default primary key.", airbyte_hidden=True, # Users can create/modify primary keys in the connection configuration so we shouldn't duplicate it here. @@ -56,9 +60,9 @@ class FileBasedStreamConfig(BaseModel): description="When the state history of the file store is full, syncs will only read files that were last modified in the provided day range.", default=3, ) - format: Union[ - AvroFormat, CsvFormat, JsonlFormat, ParquetFormat, UnstructuredFormat, ExcelFormat - ] = Field( + format: ( + AvroFormat | CsvFormat | JsonlFormat | ParquetFormat | UnstructuredFormat | ExcelFormat + ) = Field( title="Format", description="The configuration options that are used to alter how to read incoming files that deviate from the standard formatting.", ) @@ -67,7 +71,7 @@ class FileBasedStreamConfig(BaseModel): description="When enabled, syncs will not validate or structure records against the stream's schema.", default=False, ) - recent_n_files_to_read_for_schema_discovery: Optional[int] = Field( + recent_n_files_to_read_for_schema_discovery: int | None = Field( title="Files To Read For Schema Discover", description="The number of resent files which will be used to discover the schema for this stream.", default=None, @@ -75,17 +79,15 @@ class FileBasedStreamConfig(BaseModel): ) @validator("input_schema", pre=True) - def validate_input_schema(cls, v: Optional[str]) -> Optional[str]: + def validate_input_schema(cls, v: str | None) -> str | None: if v: if type_mapping_to_jsonschema(v): return v - else: - raise ConfigValidationError(FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA) + raise ConfigValidationError(FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA) return None - def get_input_schema(self) -> Optional[Mapping[str, Any]]: - """ - User defined input_schema is defined as a string in the config. This method takes the string representation + def get_input_schema(self) -> Mapping[str, Any] | None: + """User defined input_schema is defined as a string in the config. This method takes the string representation and converts it into a Mapping[str, Any] which is used by file-based CDK components. """ if self.input_schema: diff --git a/airbyte_cdk/sources/file_based/config/jsonl_format.py b/airbyte_cdk/sources/file_based/config/jsonl_format.py index 1d9ed54f..9fe4e9d1 100644 --- a/airbyte_cdk/sources/file_based/config/jsonl_format.py +++ b/airbyte_cdk/sources/file_based/config/jsonl_format.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
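For reference, a sketch of a FileBasedStreamConfig combining the CSV options above; the stream name, glob, and column names are hypothetical:

from airbyte_cdk.sources.file_based.config.csv_format import (
    CsvFormat,
    CsvHeaderUserProvided,
)
from airbyte_cdk.sources.file_based.config.file_based_stream_config import (
    FileBasedStreamConfig,
)

stream_config = FileBasedStreamConfig(
    name="users",
    globs=["data/users/**/*.csv"],
    format=CsvFormat(
        header_definition=CsvHeaderUserProvided(column_names=["id", "name"]),
        null_values={"NA", ""},
    ),
)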
# +from __future__ import annotations -from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig from pydantic.v1 import BaseModel, Field +from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig + class JsonlFormat(BaseModel): class Config(OneOfOptionConfig): diff --git a/airbyte_cdk/sources/file_based/config/parquet_format.py b/airbyte_cdk/sources/file_based/config/parquet_format.py index 7c40f8e3..d7cc7a61 100644 --- a/airbyte_cdk/sources/file_based/config/parquet_format.py +++ b/airbyte_cdk/sources/file_based/config/parquet_format.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from pydantic.v1 import BaseModel, Field from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig -from pydantic.v1 import BaseModel, Field class ParquetFormat(BaseModel): diff --git a/airbyte_cdk/sources/file_based/config/unstructured_format.py b/airbyte_cdk/sources/file_based/config/unstructured_format.py index dcebd951..854b35b0 100644 --- a/airbyte_cdk/sources/file_based/config/unstructured_format.py +++ b/airbyte_cdk/sources/file_based/config/unstructured_format.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import List, Literal, Optional, Union +from typing import Literal -from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig from pydantic.v1 import BaseModel, Field +from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig + class LocalProcessingConfigModel(BaseModel): mode: Literal["local"] = Field("local", const=True) @@ -49,7 +51,7 @@ class APIProcessingConfigModel(BaseModel): examples=["https://api.unstructured.com"], ) - parameters: Optional[List[APIParameterConfigModel]] = Field( + parameters: list[APIParameterConfigModel] | None = Field( default=[], always_show=True, title="Additional URL Parameters", @@ -89,10 +91,7 @@ class Config(OneOfOptionConfig): description="The strategy used to parse documents. `fast` extracts text directly from the document which doesn't work for all files. `ocr_only` is more reliable, but slower. `hi_res` is the most reliable, but requires an API key and a hosted instance of unstructured and can't be used with local mode. See the unstructured.io documentation for more details: https://unstructured-io.github.io/unstructured/core/partition.html#partition-pdf", ) - processing: Union[ - LocalProcessingConfigModel, - APIProcessingConfigModel, - ] = Field( + processing: LocalProcessingConfigModel | APIProcessingConfigModel = Field( default=LocalProcessingConfigModel(mode="local"), title="Processing", description="Processing configuration", diff --git a/airbyte_cdk/sources/file_based/discovery_policy/abstract_discovery_policy.py b/airbyte_cdk/sources/file_based/discovery_policy/abstract_discovery_policy.py index 115382cf..671c0324 100644 --- a/airbyte_cdk/sources/file_based/discovery_policy/abstract_discovery_policy.py +++ b/airbyte_cdk/sources/file_based/discovery_policy/abstract_discovery_policy.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod @@ -8,8 +9,7 @@ class AbstractDiscoveryPolicy(ABC): - """ - Used during discovery; allows the developer to configure the number of concurrent + """Used during discovery; allows the developer to configure the number of concurrent requests to send to the source, and the number of files to use for schema discovery. 
""" diff --git a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py index f651c2ce..383b26b4 100644 --- a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py +++ b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py @@ -1,19 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from airbyte_cdk.sources.file_based.discovery_policy.abstract_discovery_policy import ( AbstractDiscoveryPolicy, ) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser + DEFAULT_N_CONCURRENT_REQUESTS = 10 DEFAULT_MAX_N_FILES_FOR_STREAM_SCHEMA_INFERENCE = 10 class DefaultDiscoveryPolicy(AbstractDiscoveryPolicy): - """ - Default number of concurrent requests to send to the source on discover, and number + """Default number of concurrent requests to send to the source on discover, and number of files to use for schema inference. """ diff --git a/airbyte_cdk/sources/file_based/exceptions.py b/airbyte_cdk/sources/file_based/exceptions.py index 1c5ce0b1..b91d48d1 100644 --- a/airbyte_cdk/sources/file_based/exceptions.py +++ b/airbyte_cdk/sources/file_based/exceptions.py @@ -1,9 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from enum import Enum -from typing import Any, List, Union +from typing import Any, Union from airbyte_cdk.models import AirbyteMessage, FailureType from airbyte_cdk.utils import AirbyteTracedException @@ -39,11 +40,9 @@ class FileBasedSourceError(Enum): class FileBasedErrorsCollector: - """ - The placeholder for all errors collected. - """ + """The placeholder for all errors collected.""" - errors: List[AirbyteMessage] = [] + errors: list[AirbyteMessage] = [] def yield_and_raise_collected(self) -> Any: if self.errors: @@ -112,8 +111,7 @@ class ErrorListingFiles(BaseFileBasedSourceError): class CustomFileBasedException(AirbyteTracedException): - """ - A specialized exception for file-based connectors. + """A specialized exception for file-based connectors. This exception is designed to bypass the default error handling in the file-based CDK, allowing the use of custom error messages. """ diff --git a/airbyte_cdk/sources/file_based/file_based_source.py b/airbyte_cdk/sources/file_based/file_based_source.py index 2c5758b2..ea66da01 100644 --- a/airbyte_cdk/sources/file_based/file_based_source.py +++ b/airbyte_cdk/sources/file_based/file_based_source.py @@ -1,12 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging import traceback from abc import ABC from collections import Counter -from typing import Any, Iterator, List, Mapping, Optional, Tuple, Type, Union +from collections.abc import Iterator, Mapping +from typing import Any + +from pydantic.v1.error_wrappers import ValidationError from airbyte_cdk.logger import AirbyteLogFormatter, init_logger from airbyte_cdk.models import ( @@ -60,7 +64,7 @@ from airbyte_cdk.sources.streams.concurrent.cursor import CursorField from airbyte_cdk.utils.analytics_message import create_analytics_message from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from pydantic.v1.error_wrappers import ValidationError + DEFAULT_CONCURRENCY = 100 MAX_CONCURRENCY = 100 @@ -74,18 +78,18 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC): def __init__( self, stream_reader: AbstractFileBasedStreamReader, - spec_class: Type[AbstractFileBasedSpec], - catalog: Optional[ConfiguredAirbyteCatalog], - config: Optional[Mapping[str, Any]], - state: Optional[List[AirbyteStateMessage]], - availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None, + spec_class: type[AbstractFileBasedSpec], + catalog: ConfiguredAirbyteCatalog | None, + config: Mapping[str, Any] | None, + state: list[AirbyteStateMessage] | None, + availability_strategy: AbstractFileBasedAvailabilityStrategy | None = None, discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(), - parsers: Mapping[Type[Any], FileTypeParser] = default_parsers, + parsers: Mapping[type[Any], FileTypeParser] = default_parsers, validation_policies: Mapping[ ValidationPolicy, AbstractSchemaValidationPolicy ] = DEFAULT_SCHEMA_VALIDATION_POLICIES, - cursor_cls: Type[ - Union[AbstractConcurrentFileBasedCursor, AbstractFileBasedCursor] + cursor_cls: type[ + AbstractConcurrentFileBasedCursor | AbstractFileBasedCursor ] = FileBasedConcurrentCursor, ): self.stream_reader = stream_reader @@ -105,7 +109,7 @@ def __init__( self.cursor_cls = cursor_cls self.logger = init_logger(f"airbyte.{self.name}") self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector() - self._message_repository: Optional[MessageRepository] = None + self._message_repository: MessageRepository | None = None concurrent_source = ConcurrentSource.create( MAX_CONCURRENCY, INITIAL_N_PARTITIONS, @@ -126,9 +130,8 @@ def message_repository(self) -> MessageRepository: def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: - """ - Check that the source can be accessed using the user-provided configuration. + ) -> tuple[bool, Any | None]: + """Check that the source can be accessed using the user-provided configuration. For each stream, verify that we can list and read files. @@ -190,7 +193,7 @@ def check_connection( message=f"{errors[0]}", failure_type=FailureType.config_error, ) - elif len(errors) > 1: + if len(errors) > 1: raise AirbyteTracedException( internal_message="\n".join(tracebacks), message=f"{len(errors)} streams with errors: {', '.join(error for error in errors)}", @@ -199,11 +202,8 @@ def check_connection( return not bool(errors), (errors or None) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: - """ - Return a list of this source's streams. 
- """ - + def streams(self, config: Mapping[str, Any]) -> list[Stream]: + """Return a list of this source's streams.""" if self.catalog: state_manager = ConnectorStateManager(state=self.state) else: @@ -214,7 +214,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: try: parsed_config = self._get_parsed_config(config) self.stream_reader.config = parsed_config - streams: List[Stream] = [] + streams: list[Stream] = [] for stream_config in parsed_config.streams: # Like state_manager, `catalog_stream` may be None during `check` catalog_stream = self._get_stream_from_catalog(stream_config) @@ -296,7 +296,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: def _make_default_stream( self, stream_config: FileBasedStreamConfig, - cursor: Optional[AbstractFileBasedCursor], + cursor: AbstractFileBasedCursor | None, use_file_transfer: bool = False, ) -> AbstractFileBasedStream: return DefaultFileBasedStream( @@ -314,14 +314,14 @@ def _make_default_stream( def _get_stream_from_catalog( self, stream_config: FileBasedStreamConfig - ) -> Optional[AirbyteStream]: + ) -> AirbyteStream | None: if self.catalog: for stream in self.catalog.streams or []: if stream.stream.name == stream_config.name: return stream.stream return None - def _get_sync_mode_from_catalog(self, stream_name: str) -> Optional[SyncMode]: + def _get_sync_mode_from_catalog(self, stream_name: str) -> SyncMode | None: if self.catalog: for catalog_stream in self.catalog.streams: if stream_name == catalog_stream.stream.name: @@ -334,7 +334,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: yield from super().read(logger, config, catalog, state) # emit all the errors collected @@ -347,10 +347,7 @@ def read( yield create_analytics_message(f"file-cdk-{parser}-stream-count", count) def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: - """ - Returns the specification describing what fields can be configured by a user when setting up a file-based source. - """ - + """Returns the specification describing what fields can be configured by a user when setting up a file-based source.""" return ConnectorSpecification( documentationUrl=self.spec_class.documentation_url(), connectionSpecification=self.spec_class.schema(), diff --git a/airbyte_cdk/sources/file_based/file_based_stream_reader.py b/airbyte_cdk/sources/file_based/file_based_stream_reader.py index f8a9f89f..b04da502 100644 --- a/airbyte_cdk/sources/file_based/file_based_stream_reader.py +++ b/airbyte_cdk/sources/file_based/file_based_stream_reader.py @@ -1,18 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging from abc import ABC, abstractmethod +from collections.abc import Iterable from datetime import datetime from enum import Enum from io import IOBase from os import makedirs, path -from typing import Any, Dict, Iterable, List, Optional, Set +from typing import Any + +from wcmatch.glob import GLOBSTAR, globmatch from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from wcmatch.glob import GLOBSTAR, globmatch class FileReadMode(Enum): @@ -27,14 +30,13 @@ def __init__(self) -> None: self._config = None @property - def config(self) -> Optional[AbstractFileBasedSpec]: + def config(self) -> AbstractFileBasedSpec | None: return self._config @config.setter @abstractmethod def config(self, value: AbstractFileBasedSpec) -> None: - """ - FileBasedSource reads the config from disk and parses it, and once parsed, the source sets the config on its StreamReader. + """FileBasedSource reads the config from disk and parses it, and once parsed, the source sets the config on its StreamReader. Note: FileBasedSource only requires the keys defined in the abstract config, whereas concrete implementations of StreamReader will require keys that (for example) allow it to authenticate with the 3rd party. @@ -46,10 +48,9 @@ def config(self, value: AbstractFileBasedSpec) -> None: @abstractmethod def open_file( - self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger ) -> IOBase: - """ - Return a file handle for reading. + """Return a file handle for reading. Many sources will be able to use smart_open to implement this method, for example: @@ -62,15 +63,13 @@ def open_file( @abstractmethod def get_matching_files( self, - globs: List[str], - prefix: Optional[str], + globs: list[str], + prefix: str | None, logger: logging.Logger, ) -> Iterable[RemoteFile]: - """ - Return all files that match any of the globs. + """Return all files that match any of the globs. Example: - The source has files "a.json", "foo/a.json", "foo/bar/a.json" If globs = ["*.json"] then this method returns ["a.json"]. @@ -83,11 +82,9 @@ def get_matching_files( ... def filter_files_by_globs_and_start_date( - self, files: List[RemoteFile], globs: List[str] + self, files: list[RemoteFile], globs: list[str] ) -> Iterable[RemoteFile]: - """ - Utility method for filtering files based on globs. - """ + """Utility method for filtering files based on globs.""" start_date = ( datetime.strptime(self.config.start_date, self.DATE_TIME_FORMAT) if self.config and self.config.start_date @@ -112,16 +109,14 @@ def file_size(self, file: RemoteFile) -> int: ... @staticmethod - def file_matches_globs(file: RemoteFile, globs: List[str]) -> bool: + def file_matches_globs(file: RemoteFile, globs: list[str]) -> bool: # Use the GLOBSTAR flag to enable recursive ** matching # (https://facelessuser.github.io/wcmatch/wcmatch/#globstar) return any(globmatch(file.uri, g, flags=GLOBSTAR) for g in globs) @staticmethod - def get_prefixes_from_globs(globs: List[str]) -> Set[str]: - """ - Utility method for extracting prefixes from the globs. 
- """ + def get_prefixes_from_globs(globs: list[str]) -> set[str]: + """Utility method for extracting prefixes from the globs.""" prefixes = {glob.split("*")[0] for glob in globs} return set(filter(lambda x: bool(x), prefixes)) @@ -137,9 +132,8 @@ def use_file_transfer(self) -> bool: @abstractmethod def get_file( self, file: RemoteFile, local_directory: str, logger: logging.Logger - ) -> Dict[str, Any]: - """ - This is required for connectors that will support writing to + ) -> dict[str, Any]: + """This is required for connectors that will support writing to files. It will handle the logic to download,get,read,acquire or whatever is more efficient to get a file from the source. @@ -148,7 +142,7 @@ def get_file( local_directory (str): The local directory path where the file will be downloaded. logger (logging.Logger): Logger for logging information and errors. - Returns: + Returns: dict: A dictionary containing the following: - "file_url" (str): The absolute path of the downloaded file. - "bytes" (int): The file size in bytes. @@ -159,7 +153,7 @@ def get_file( ... @staticmethod - def _get_file_transfer_paths(file: RemoteFile, local_directory: str) -> List[str]: + def _get_file_transfer_paths(file: RemoteFile, local_directory: str) -> list[str]: # Remove left slashes from source path format to make relative path for writing locally file_relative_path = file.uri.lstrip("/") local_file_path = path.join(local_directory, file_relative_path) diff --git a/airbyte_cdk/sources/file_based/file_types/avro_parser.py b/airbyte_cdk/sources/file_based/file_types/avro_parser.py index a1535eaa..7a92a353 100644 --- a/airbyte_cdk/sources/file_based/file_types/avro_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/avro_parser.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple +from collections.abc import Iterable, Mapping +from typing import Any import fastavro + from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError @@ -17,6 +20,7 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType + AVRO_TYPE_TO_JSON_TYPE = { "null": "null", "boolean": "boolean", @@ -45,10 +49,8 @@ class AvroParser(FileTypeParser): ENCODING = None - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - AvroParser does not require config checks, implicit pydantic validation is enough. - """ + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """AvroParser does not require config checks, implicit pydantic validation is enough.""" return True, None async def infer_schema( @@ -65,7 +67,7 @@ async def infer_schema( with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: avro_reader = fastavro.reader(fp) avro_schema = avro_reader.writer_schema - if not avro_schema["type"] == "record": + if avro_schema["type"] != "record": unsupported_type = avro_schema["type"] raise ValueError( f"Only record based avro files are supported. 
Found {unsupported_type}" @@ -98,7 +100,7 @@ def _convert_avro_type_to_json( for object_field in avro_field["fields"] }, } - elif avro_field["type"] == "array": + if avro_field["type"] == "array": if "items" not in avro_field: raise ValueError( f"{field_name} array type does not have a required field items" @@ -109,7 +111,7 @@ def _convert_avro_type_to_json( avro_format, "", avro_field["items"] ), } - elif avro_field["type"] == "enum": + if avro_field["type"] == "enum": if "symbols" not in avro_field: raise ValueError( f"{field_name} enum type does not have a required field symbols" @@ -117,7 +119,7 @@ def _convert_avro_type_to_json( if "name" not in avro_field: raise ValueError(f"{field_name} enum type does not have a required field name") return {"type": "string", "enum": avro_field["symbols"]} - elif avro_field["type"] == "map": + if avro_field["type"] == "map": if "values" not in avro_field: raise ValueError(f"{field_name} map type does not have a required field values") return { @@ -126,7 +128,7 @@ def _convert_avro_type_to_json( avro_format, "", avro_field["values"] ), } - elif avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration": + if avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration": if "size" not in avro_field: raise ValueError(f"{field_name} fixed type does not have a required field size") if not isinstance(avro_field["size"], int): @@ -135,7 +137,7 @@ def _convert_avro_type_to_json( "type": "string", "pattern": f"^[0-9A-Fa-f]{{{avro_field['size'] * 2}}}$", } - elif avro_field.get("logicalType") == "decimal": + if avro_field.get("logicalType") == "decimal": if "precision" not in avro_field: raise ValueError( f"{field_name} decimal type does not have a required field precision" @@ -151,18 +153,16 @@ def _convert_avro_type_to_json( # For example: ^-?\d{1,5}(?:\.\d{1,3})?$ would accept 12345.123 and 123456.12345 would be rejected return { "type": "string", - "pattern": f"^-?\\d{{{1,max_whole_number_range}}}(?:\\.\\d{1,decimal_range})?$", + "pattern": f"^-?\\d{{{1, max_whole_number_range}}}(?:\\.\\d{1, decimal_range})?$", } - elif "logicalType" in avro_field: + if "logicalType" in avro_field: if avro_field["logicalType"] not in AVRO_LOGICAL_TYPE_TO_JSON: raise ValueError( f"{avro_field['logicalType']} is not a valid Avro logical type" ) return AVRO_LOGICAL_TYPE_TO_JSON[avro_field["logicalType"]] - else: - raise ValueError(f"Unsupported avro type: {avro_field}") - else: raise ValueError(f"Unsupported avro type: {avro_field}") + raise ValueError(f"Unsupported avro type: {avro_field}") def parse_records( self, @@ -170,8 +170,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, + ) -> Iterable[dict[str, Any]]: avro_format = config.format or AvroFormat(filetype="avro") if not isinstance(avro_format, AvroFormat): raise ValueError(f"Expected ParquetFormat, got {avro_format}") @@ -189,7 +189,7 @@ def parse_records( yield { record_field: self._to_output_value( avro_format, - schema_field_name_to_type[record_field], + record_value, record[record_field], ) for record_field, record_value in schema_field_name_to_type.items() @@ -209,21 +209,20 @@ def _to_output_value( ) -> Any: if isinstance(record_value, bytes): return record_value.decode() - elif not isinstance(record_type, Mapping): + if not isinstance(record_type, Mapping): if record_type == 
"double" and avro_format.double_as_string: return str(record_value) return record_value if record_type.get("logicalType") in ("decimal", "uuid"): return str(record_value) - elif record_type.get("logicalType") == "date": + if record_type.get("logicalType") == "date": return record_value.isoformat() - elif record_type.get("logicalType") == "timestamp-millis": + if record_type.get("logicalType") == "timestamp-millis": return record_value.isoformat(sep="T", timespec="milliseconds") - elif record_type.get("logicalType") == "timestamp-micros": + if record_type.get("logicalType") == "timestamp-micros": return record_value.isoformat(sep="T", timespec="microseconds") - elif record_type.get("logicalType") == "local-timestamp-millis": + if record_type.get("logicalType") == "local-timestamp-millis": return record_value.isoformat(sep="T", timespec="milliseconds") - elif record_type.get("logicalType") == "local-timestamp-micros": + if record_type.get("logicalType") == "local-timestamp-micros": return record_value.isoformat(sep="T", timespec="microseconds") - else: - return record_value + return record_value diff --git a/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte_cdk/sources/file_based/file_types/csv_parser.py index 951be6fe..013a2c33 100644 --- a/airbyte_cdk/sources/file_based/file_types/csv_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/csv_parser.py @@ -1,17 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import csv import json import logging from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping from functools import partial from io import IOBase -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple +from typing import Any from uuid import uuid4 +from orjson import orjson + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.file_based.config.csv_format import ( CsvFormat, @@ -29,7 +33,7 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import TYPE_PYTHON_MAPPING, SchemaType from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from orjson import orjson + DIALECT_NAME = "_config_dialect" @@ -42,7 +46,7 @@ def read_data( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, file_read_mode: FileReadMode, - ) -> Generator[Dict[str, Any], None, None]: + ) -> Generator[dict[str, Any], None, None]: config_format = _extract_format(config) lineno = 0 @@ -51,7 +55,7 @@ def read_data( # Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up # with a race condition where a thread attempts to use a dialect before a separate thread has finished # registering it. 
- dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}" + dialect_name = f"{config.name}_{uuid4()!s}_{DIALECT_NAME}" csv.register_dialect( dialect_name, delimiter=config_format.delimiter, @@ -110,10 +114,8 @@ def read_data( # due to RecordParseError or GeneratorExit csv.unregister_dialect(dialect_name) - def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]: - """ - Assumes the fp is pointing to the beginning of the files and will reset it as such - """ + def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> list[str]: + """Assumes the fp is pointing to the beginning of the files and will reset it as such""" # Note that this method assumes the dialect has already been registered if we're parsing the headers if isinstance(config_format.header_definition, CsvHeaderUserProvided): return config_format.header_definition.column_names # type: ignore # should be CsvHeaderUserProvided given the type @@ -132,9 +134,8 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) fp.seek(0) return headers - def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> List[str]: - """ - Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True. + def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> list[str]: + """Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True. See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html """ reader = csv.reader(fp, dialect=dialect_name) # type: ignore @@ -143,9 +144,7 @@ def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> List[str]: @staticmethod def _skip_rows(fp: IOBase, rows_to_skip: int) -> None: - """ - Skip rows before the header. This has to be done on the file object itself, not the reader - """ + """Skip rows before the header. This has to be done on the file object itself, not the reader""" for _ in range(rows_to_skip): fp.readline() @@ -153,17 +152,15 @@ def _skip_rows(fp: IOBase, rows_to_skip: int) -> None: class CsvParser(FileTypeParser): _MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000 - def __init__(self, csv_reader: Optional[_CsvReader] = None, csv_field_max_bytes: int = 2**31): + def __init__(self, csv_reader: _CsvReader | None = None, csv_field_max_bytes: int = 2**31): # Increase the maximum length of data that can be parsed in a single CSV field. The default is 128k, which is typically sufficient # but given the use of Airbyte in loading a large variety of data it is best to allow for a larger maximum field size to avoid # skipping data on load. https://stackoverflow.com/questions/15063936/csv-error-field-larger-than-field-limit-131072 csv.field_size_limit(csv_field_max_bytes) - self._csv_reader = csv_reader if csv_reader else _CsvReader() + self._csv_reader = csv_reader or _CsvReader() - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - CsvParser does not require config checks, implicit pydantic validation is enough. 
- """ + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """CsvParser does not require config checks, implicit pydantic validation is enough.""" return True, None async def infer_schema( @@ -177,10 +174,10 @@ async def infer_schema( if input_schema: return input_schema - # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual + # TODO: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual # sources will likely require one. Rather than modify the interface now we can wait until the real use case config_format = _extract_format(config) - type_inferrer_by_field: Dict[str, _TypeInferrer] = defaultdict( + type_inferrer_by_field: dict[str, _TypeInferrer] = defaultdict( lambda: _JsonTypeInferrer( config_format.true_values, config_format.false_values, config_format.null_values ) @@ -220,8 +217,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, + ) -> Iterable[dict[str, Any]]: line_no = 0 try: config_format = _extract_format(config) @@ -272,17 +269,16 @@ def _get_cast_function( config_format=config_format, logger=logger, ) - else: - # If no schema is provided, yield the rows as they are - return _no_cast + # If no schema is provided, yield the rows as they are + return _no_cast @staticmethod def _to_nullable( row: Mapping[str, str], deduped_property_types: Mapping[str, str], - null_values: Set[str], + null_values: set[str], strings_can_be_null: bool, - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: nullable = { k: None if CsvParser._value_is_none( @@ -296,16 +292,15 @@ def _to_nullable( @staticmethod def _value_is_none( value: Any, - deduped_property_type: Optional[str], - null_values: Set[str], + deduped_property_type: str | None, + null_values: set[str], strings_can_be_null: bool, ) -> bool: return value in null_values and (strings_can_be_null or deduped_property_type != "string") @staticmethod - def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str, str]: - """ - Transform the property types to be non-nullable and remove duplicate types if any. + def _pre_propcess_property_types(property_types: dict[str, Any]) -> Mapping[str, str]: + """Transform the property types to be non-nullable and remove duplicate types if any. Sample input: { "col1": ["string", "null"], @@ -334,13 +329,12 @@ def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str, @staticmethod def _cast_types( - row: Dict[str, str], + row: dict[str, str], deduped_property_types: Mapping[str, str], config_format: CsvFormat, logger: logging.Logger, - ) -> Dict[str, Any]: - """ - Casts the values in the input 'row' dictionary according to the types defined in the JSON schema. + ) -> dict[str, Any]: + """Casts the values in the input 'row' dictionary according to the types defined in the JSON schema. Array and object types are only handled if they can be deserialized as JSON. 
@@ -424,12 +418,12 @@ class _JsonTypeInferrer(_TypeInferrer): _STRING_TYPE = "string" def __init__( - self, boolean_trues: Set[str], boolean_falses: Set[str], null_values: Set[str] + self, boolean_trues: set[str], boolean_falses: set[str], null_values: set[str] ) -> None: self._boolean_trues = boolean_trues self._boolean_falses = boolean_falses self._null_values = null_values - self._values: Set[str] = set() + self._values: set[str] = set() def add_value(self, value: Any) -> None: self._values.add(value) @@ -446,13 +440,13 @@ def infer(self) -> str: types = set.intersection(*types_excluding_null_values) if self._BOOLEAN_TYPE in types: return self._BOOLEAN_TYPE - elif self._INTEGER_TYPE in types: + if self._INTEGER_TYPE in types: return self._INTEGER_TYPE - elif self._NUMBER_TYPE in types: + if self._NUMBER_TYPE in types: return self._NUMBER_TYPE return self._STRING_TYPE - def _infer_type(self, value: str) -> Set[str]: + def _infer_type(self, value: str) -> set[str]: inferred_types = set() if value in self._null_values: @@ -492,7 +486,7 @@ def _is_number(value: str) -> bool: return False -def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool: +def _value_to_bool(value: str, true_values: set[str], false_values: set[str]) -> bool: if value in true_values: return True if value in false_values: @@ -500,7 +494,7 @@ def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> raise ValueError(f"Value {value} is not a valid boolean value") -def _value_to_list(value: str) -> List[Any]: +def _value_to_list(value: str) -> list[Any]: parsed_value = json.loads(value) if isinstance(parsed_value, list): return parsed_value @@ -511,7 +505,7 @@ def _value_to_python_type(value: str, python_type: type) -> Any: return python_type(value) -def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str: +def _format_warning(key: str, value: str, expected_type: Any | None) -> str: return f"{key}: value={value},expected_type={expected_type}" diff --git a/airbyte_cdk/sources/file_based/file_types/excel_parser.py b/airbyte_cdk/sources/file_based/file_types/excel_parser.py index 0c0da8b3..6e129e16 100644 --- a/airbyte_cdk/sources/file_based/file_types/excel_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/excel_parser.py @@ -1,13 +1,20 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Iterable, Mapping from io import IOBase from pathlib import Path -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import Any import pandas as pd +from numpy import datetime64, issubdtype +from numpy import dtype as dtype_ +from orjson import orjson +from pydantic.v1 import BaseModel + from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( ExcelFormat, FileBasedStreamConfig, @@ -24,20 +31,13 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType -from numpy import datetime64 -from numpy import dtype as dtype_ -from numpy import issubdtype -from orjson import orjson -from pydantic.v1 import BaseModel class ExcelParser(FileTypeParser): ENCODING = None - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - ExcelParser does not require config checks, implicit pydantic validation is enough. 
- """ + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """ExcelParser does not require config checks, implicit pydantic validation is enough.""" return True, None async def infer_schema( @@ -47,8 +47,7 @@ async def infer_schema( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, ) -> SchemaType: - """ - Infers the schema of the Excel file by examining its contents. + """Infers the schema of the Excel file by examining its contents. Args: config (FileBasedStreamConfig): Configuration for the file-based stream. @@ -59,11 +58,10 @@ async def infer_schema( Returns: SchemaType: Inferred schema of the Excel file. """ - # Validate the format of the config self.validate_format(config.format, logger) - fields: Dict[str, str] = {} + fields: dict[str, str] = {} with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: df = self.open_and_parse_file(fp) @@ -88,10 +86,9 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]] = None, - ) -> Iterable[Dict[str, Any]]: - """ - Parses records from an Excel file based on the provided configuration. + discovered_schema: Mapping[str, SchemaType] | None = None, + ) -> Iterable[dict[str, Any]]: + """Parses records from an Excel file based on the provided configuration. Args: config (FileBasedStreamConfig): Configuration for the file-based stream. @@ -103,7 +100,6 @@ def parse_records( Yields: Iterable[Dict[str, Any]]: Parsed records from the Excel file. """ - # Validate the format of the config self.validate_format(config.format, logger) @@ -127,8 +123,7 @@ def parse_records( @property def file_read_mode(self) -> FileReadMode: - """ - Returns the file read mode for the Excel file. + """Returns the file read mode for the Excel file. Returns: FileReadMode: The file read mode (binary). @@ -136,9 +131,8 @@ def file_read_mode(self) -> FileReadMode: return FileReadMode.READ_BINARY @staticmethod - def dtype_to_json_type(current_type: Optional[str], dtype: dtype_) -> str: - """ - Convert Pandas DataFrame types to Airbyte Types. + def dtype_to_json_type(current_type: str | None, dtype: dtype_) -> str: + """Convert Pandas DataFrame types to Airbyte Types. Args: current_type (Optional[str]): One of the previous types based on earlier dataframes. @@ -163,8 +157,7 @@ def dtype_to_json_type(current_type: Optional[str], dtype: dtype_) -> str: @staticmethod def validate_format(excel_format: BaseModel, logger: logging.Logger) -> None: - """ - Validates if the given format is of type ExcelFormat. + """Validates if the given format is of type ExcelFormat. Args: excel_format (Any): The format to be validated. @@ -177,9 +170,8 @@ def validate_format(excel_format: BaseModel, logger: logging.Logger) -> None: raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) @staticmethod - def open_and_parse_file(fp: Union[IOBase, str, Path]) -> pd.DataFrame: - """ - Opens and parses the Excel file. + def open_and_parse_file(fp: IOBase | str | Path) -> pd.DataFrame: + """Opens and parses the Excel file. Args: fp: File pointer to the Excel file. 
diff --git a/airbyte_cdk/sources/file_based/file_types/file_transfer.py b/airbyte_cdk/sources/file_based/file_types/file_transfer.py index 154b6ff4..5dbf7d79 100644 --- a/airbyte_cdk/sources/file_based/file_types/file_transfer.py +++ b/airbyte_cdk/sources/file_based/file_types/file_transfer.py @@ -1,14 +1,18 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import logging import os -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.remote_file import RemoteFile + AIRBYTE_STAGING_DIRECTORY = os.getenv("AIRBYTE_STAGING_DIRECTORY", "/staging/files") DEFAULT_LOCAL_DIRECTORY = "/tmp/airbyte-file-transfer" @@ -27,7 +31,7 @@ def get_file( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - ) -> Iterable[Dict[str, Any]]: + ) -> Iterable[dict[str, Any]]: try: yield stream_reader.get_file( file=file, local_directory=self._local_directory, logger=logger diff --git a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py index e6a9c5cb..f6f1825a 100644 --- a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.file_based_stream_reader import ( @@ -14,40 +16,32 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType -Record = Dict[str, Any] + +Record = dict[str, Any] class FileTypeParser(ABC): - """ - An abstract class containing methods that must be implemented for each + """An abstract class containing methods that must be implemented for each supported file type. """ @property - def parser_max_n_files_for_schema_inference(self) -> Optional[int]: - """ - The discovery policy decides how many files are loaded for schema inference. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used. - """ + def parser_max_n_files_for_schema_inference(self) -> int | None: + """The discovery policy decides how many files are loaded for schema inference. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used.""" return None @property - def parser_max_n_files_for_parsability(self) -> Optional[int]: - """ - The availability policy decides how many files are loaded for checking whether parsing works correctly. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used. - """ + def parser_max_n_files_for_parsability(self) -> int | None: + """The availability policy decides how many files are loaded for checking whether parsing works correctly. This method can provide a parser-specific override. 
If it's defined, the smaller of the two values will be used.""" return None - def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> Optional[str]: - """ - The parser can define a primary key. If no user-defined primary key is provided, this will be used. - """ + def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> str | None: + """The parser can define a primary key. If no user-defined primary key is provided, this will be used.""" return None @abstractmethod - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - Check whether the config is valid for this file type. If it is, return True and None. If it's not, return False and an error message explaining why it's invalid. - """ + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """Check whether the config is valid for this file type. If it is, return True and None. If it's not, return False and an error message explaining why it's invalid.""" return True, None @abstractmethod @@ -58,9 +52,7 @@ async def infer_schema( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, ) -> SchemaType: - """ - Infer the JSON Schema for this file. - """ + """Infer the JSON Schema for this file.""" ... @abstractmethod @@ -70,17 +62,13 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], + discovered_schema: Mapping[str, SchemaType] | None, ) -> Iterable[Record]: - """ - Parse and emit each record. - """ + """Parse and emit each record.""" ... @property @abstractmethod def file_read_mode(self) -> FileReadMode: - """ - The mode in which the file should be opened for reading. - """ + """The mode in which the file should be opened for reading.""" ... diff --git a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py index 6cd59075..ce8a0b7d 100644 --- a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py @@ -1,10 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any + +from orjson import orjson from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError @@ -19,17 +23,14 @@ SchemaType, merge_schemas, ) -from orjson import orjson class JsonlParser(FileTypeParser): MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000 ENCODING = "utf8" - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - JsonlParser does not require config checks, implicit pydantic validation is enough. 
- """ + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """JsonlParser does not require config checks, implicit pydantic validation is enough.""" return True, None async def infer_schema( @@ -39,8 +40,7 @@ async def infer_schema( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, ) -> SchemaType: - """ - Infers the schema for the file by inferring the schema for each line, and merging + """Infers the schema for the file by inferring the schema for each line, and merging it with the previously-inferred schema. """ inferred_schema: Mapping[str, Any] = {} @@ -57,10 +57,9 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: - """ - This code supports parsing json objects over multiple lines even though this does not align with the JSONL format. This is for + discovered_schema: Mapping[str, SchemaType] | None, + ) -> Iterable[dict[str, Any]]: + """This code supports parsing json objects over multiple lines even though this does not align with the JSONL format. This is for backward compatibility reasons i.e. the previous source-s3 parser did support this. The drawback is: * performance as the way we support json over multiple lines is very brute forced * given that we don't have `newlines_in_values` config to scope the possible inputs, we might parse the whole file before knowing if @@ -72,7 +71,7 @@ def parse_records( yield from self._parse_jsonl_entries(file, stream_reader, logger) @classmethod - def _infer_schema_for_record(cls, record: Dict[str, Any]) -> Dict[str, Any]: + def _infer_schema_for_record(cls, record: dict[str, Any]) -> dict[str, Any]: record_schema = {} for key, value in record.items(): if value is None: @@ -92,7 +91,7 @@ def _parse_jsonl_entries( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, read_limit: bool = False, - ) -> Iterable[Dict[str, Any]]: + ) -> Iterable[dict[str, Any]]: with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: read_bytes = 0 @@ -137,8 +136,8 @@ def _parse_jsonl_entries( ) @staticmethod - def _instantiate_accumulator(line: Union[bytes, str]) -> Union[bytes, str]: + def _instantiate_accumulator(line: bytes | str) -> bytes | str: if isinstance(line, bytes): return bytes("", json.detect_encoding(line)) - elif isinstance(line, str): + if isinstance(line, str): return "" diff --git a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py index 99b6373d..e916b3bb 100644 --- a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py @@ -1,15 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json import logging import os -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any from urllib.parse import unquote import pyarrow as pa import pyarrow.parquet as pq +from pyarrow import DictionaryArray, Scalar + from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( FileBasedStreamConfig, ParquetFormat, @@ -26,16 +30,13 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType -from pyarrow import DictionaryArray, Scalar class ParquetParser(FileTypeParser): ENCODING = None - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - ParquetParser does not require config checks, implicit pydantic validation is enough. - """ + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """ParquetParser does not require config checks, implicit pydantic validation is enough.""" return True, None async def infer_schema( @@ -73,8 +74,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, + ) -> Iterable[dict[str, Any]]: parquet_format = config.format if not isinstance(parquet_format, ParquetFormat): logger.info(f"Expected ParquetFormat, got {parquet_format}") @@ -108,7 +109,7 @@ def parse_records( ) from exc @staticmethod - def _extract_partitions(filepath: str) -> List[str]: + def _extract_partitions(filepath: str) -> list[str]: return [unquote(partition) for partition in filepath.split(os.sep) if "=" in partition] @property @@ -117,21 +118,16 @@ def file_read_mode(self) -> FileReadMode: @staticmethod def _to_output_value( - parquet_value: Union[Scalar, DictionaryArray], parquet_format: ParquetFormat + parquet_value: Scalar | DictionaryArray, parquet_format: ParquetFormat ) -> Any: - """ - Convert an entry in a pyarrow table to a value that can be output by the source. - """ + """Convert an entry in a pyarrow table to a value that can be output by the source.""" if isinstance(parquet_value, DictionaryArray): return ParquetParser._dictionary_array_to_python_value(parquet_value) - else: - return ParquetParser._scalar_to_python_value(parquet_value, parquet_format) + return ParquetParser._scalar_to_python_value(parquet_value, parquet_format) @staticmethod def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat) -> Any: - """ - Convert a pyarrow scalar to a value that can be output by the source. 
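A minimal sketch of the decimal branch of _scalar_to_python_value (illustrative only; decimal_scalar_to_output is an invented name, and decimal_as_float stands in for the ParquetFormat option used in this file):

from decimal import Decimal

import pyarrow as pa

def decimal_scalar_to_output(parquet_value, decimal_as_float):
    # Decimals either become floats or lossless strings depending on the format flag.
    if pa.types.is_decimal(parquet_value.type):
        py_value = parquet_value.as_py()
        return float(py_value) if decimal_as_float else str(py_value)
    return parquet_value.as_py()

value = pa.scalar(Decimal("1.25"))
print(decimal_scalar_to_output(value, decimal_as_float=True))   # 1.25
print(decimal_scalar_to_output(value, decimal_as_float=False))  # '1.25'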
- """ + """Convert a pyarrow scalar to a value that can be output by the source.""" if parquet_value.as_py() is None: return None @@ -154,8 +150,7 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat if pa.types.is_decimal(parquet_value.type): if parquet_format.decimal_as_float: return float(parquet_value.as_py()) - else: - return str(parquet_value.as_py()) + return str(parquet_value.as_py()) if pa.types.is_map(parquet_value.type): return {k: v for k, v in parquet_value.as_py()} @@ -169,26 +164,22 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat duration_seconds = duration.total_seconds() if parquet_value.type.unit == "s": return duration_seconds - elif parquet_value.type.unit == "ms": + if parquet_value.type.unit == "ms": return duration_seconds * 1000 - elif parquet_value.type.unit == "us": + if parquet_value.type.unit == "us": return duration_seconds * 1_000_000 - elif parquet_value.type.unit == "ns": + if parquet_value.type.unit == "ns": return duration_seconds * 1_000_000_000 + duration.nanoseconds - else: - raise ValueError(f"Unknown duration unit: {parquet_value.type.unit}") - else: - return parquet_value.as_py() + raise ValueError(f"Unknown duration unit: {parquet_value.type.unit}") + return parquet_value.as_py() @staticmethod - def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[str, Any]: - """ - Convert a pyarrow dictionary array to a value that can be output by the source. + def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> dict[str, Any]: + """Convert a pyarrow dictionary array to a value that can be output by the source. Dictionaries are stored as two columns: indices and values The indices column is an array of integers that maps to the values column """ - return { "indices": parquet_value.indices.tolist(), "values": parquet_value.dictionary.tolist(), @@ -198,31 +189,28 @@ def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[st def parquet_type_to_schema_type( parquet_type: pa.DataType, parquet_format: ParquetFormat ) -> Mapping[str, str]: - """ - Convert a pyarrow data type to an Airbyte schema type. + """Convert a pyarrow data type to an Airbyte schema type. 
Parquet data types are defined at https://arrow.apache.org/docs/python/api/datatypes.html """ - if pa.types.is_timestamp(parquet_type): return {"type": "string", "format": "date-time"} - elif pa.types.is_date(parquet_type): + if pa.types.is_date(parquet_type): return {"type": "string", "format": "date"} - elif ParquetParser._is_string(parquet_type, parquet_format): + if ParquetParser._is_string(parquet_type, parquet_format): return {"type": "string"} - elif pa.types.is_boolean(parquet_type): + if pa.types.is_boolean(parquet_type): return {"type": "boolean"} - elif ParquetParser._is_integer(parquet_type): + if ParquetParser._is_integer(parquet_type): return {"type": "integer"} - elif ParquetParser._is_float(parquet_type, parquet_format): + if ParquetParser._is_float(parquet_type, parquet_format): return {"type": "number"} - elif ParquetParser._is_object(parquet_type): + if ParquetParser._is_object(parquet_type): return {"type": "object"} - elif ParquetParser._is_list(parquet_type): + if ParquetParser._is_list(parquet_type): return {"type": "array"} - elif pa.types.is_null(parquet_type): + if pa.types.is_null(parquet_type): return {"type": "null"} - else: - raise ValueError(f"Unsupported parquet type: {parquet_type}") + raise ValueError(f"Unsupported parquet type: {parquet_type}") @staticmethod def _is_binary(parquet_type: pa.DataType) -> bool: @@ -240,22 +228,20 @@ def _is_integer(parquet_type: pa.DataType) -> bool: def _is_float(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool: if pa.types.is_decimal(parquet_type): return parquet_format.decimal_as_float - else: - return bool(pa.types.is_floating(parquet_type)) + return bool(pa.types.is_floating(parquet_type)) @staticmethod def _is_string(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool: if pa.types.is_decimal(parquet_type): return not parquet_format.decimal_as_float - else: - return bool( - pa.types.is_time(parquet_type) - or pa.types.is_string(parquet_type) - or pa.types.is_large_string(parquet_type) - or ParquetParser._is_binary( - parquet_type - ) # Best we can do is return as a string since we do not support binary - ) + return bool( + pa.types.is_time(parquet_type) + or pa.types.is_string(parquet_type) + or pa.types.is_large_string(parquet_type) + or ParquetParser._is_binary( + parquet_type + ) # Best we can do is return as a string since we do not support binary + ) @staticmethod def _is_object(parquet_type: pa.DataType) -> bool: diff --git a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py index e397ceae..401b3f00 100644 --- a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py @@ -1,15 +1,25 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import logging import traceback +from collections.abc import Iterable, Mapping from datetime import datetime from io import BytesIO, IOBase -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any import backoff import dpath import requests +from unstructured.file_utils.filetype import ( + FILETYPE_TO_MIMETYPE, + STR_TO_FILETYPE, + FileType, + detect_filetype, +) + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.config.unstructured_format import ( @@ -28,19 +38,14 @@ from airbyte_cdk.sources.file_based.schema_helpers import SchemaType from airbyte_cdk.utils import is_cloud_environment from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unstructured.file_utils.filetype import ( - FILETYPE_TO_MIMETYPE, - STR_TO_FILETYPE, - FileType, - detect_filetype, -) + unstructured_partition_pdf = None unstructured_partition_docx = None unstructured_partition_pptx = None -def optional_decode(contents: Union[str, bytes]) -> str: +def optional_decode(contents: str | bytes) -> str: if isinstance(contents, bytes): return contents.decode("utf-8") return contents @@ -48,9 +53,7 @@ def optional_decode(contents: Union[str, bytes]) -> str: def _import_unstructured() -> None: """Dynamically imported as needed, due to slow import speed.""" - global unstructured_partition_pdf - global unstructured_partition_docx - global unstructured_partition_pptx + global unstructured_partition_pdf, unstructured_partition_docx, unstructured_partition_pptx from unstructured.partition.docx import partition_docx from unstructured.partition.pdf import partition_pdf from unstructured.partition.pptx import partition_pptx @@ -62,9 +65,7 @@ def _import_unstructured() -> None: def user_error(e: Exception) -> bool: - """ - Return True if this exception is caused by user error, False otherwise. - """ + """Return True if this exception is caused by user error, False otherwise.""" if not isinstance(e, RecordParseError): return False if not isinstance(e, requests.exceptions.RequestException): @@ -77,22 +78,17 @@ def user_error(e: Exception) -> bool: class UnstructuredParser(FileTypeParser): @property - def parser_max_n_files_for_schema_inference(self) -> Optional[int]: - """ - Just check one file as the schema is static - """ + def parser_max_n_files_for_schema_inference(self) -> int | None: + """Just check one file as the schema is static""" return 1 @property - def parser_max_n_files_for_parsability(self) -> Optional[int]: - """ - Do not check any files for parsability because it might be an expensive operation and doesn't give much confidence whether the sync will succeed. - """ + def parser_max_n_files_for_parsability(self) -> int | None: + """Do not check any files for parsability because it might be an expensive operation and doesn't give much confidence whether the sync will succeed.""" return 0 - def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> Optional[str]: - """ - Return the document_key field as the primary key. + def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> str | None: + """Return the document_key field as the primary key. his will pre-select the document key column as the primary key when setting up a connection, making it easier for the user to configure normalization in the destination. 
""" @@ -133,8 +129,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, + ) -> Iterable[dict[str, Any]]: format = _extract_format(config) with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle: try: @@ -186,7 +182,7 @@ def _read_file( raise self._create_parse_error(remote_file, self._get_file_type_error_message(filetype)) if format.processing.mode == "local": return self._read_file_locally(file_handle, filetype, format.strategy, remote_file) - elif format.processing.mode == "api": + if format.processing.mode == "api": try: result: str = self._read_file_remotely_with_retries( file_handle, format.processing, filetype, format.strategy, remote_file @@ -205,9 +201,9 @@ def _read_file( return result def _params_to_dict( - self, params: Optional[List[APIParameterConfigModel]], strategy: str - ) -> Dict[str, Union[str, List[str]]]: - result_dict: Dict[str, Union[str, List[str]]] = {"strategy": strategy} + self, params: list[APIParameterConfigModel] | None, strategy: str + ) -> dict[str, str | list[str]]: + result_dict: dict[str, str | list[str]] = {"strategy": strategy} if params is None: return result_dict for item in params: @@ -226,9 +222,8 @@ def _params_to_dict( return result_dict - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: - """ - Perform a connection check for the parser config: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: + """Perform a connection check for the parser config: - Verify that encryption is enabled if the API is hosted on a cloud instance. - Verify that the API can extract text from a file. @@ -267,9 +262,7 @@ def _read_file_remotely_with_retries( strategy: str, remote_file: RemoteFile, ) -> str: - """ - Read a file remotely, retrying up to 5 times if the error is not caused by user error. This is useful for transient network errors or the API server being overloaded temporarily. - """ + """Read a file remotely, retrying up to 5 times if the error is not caused by user error. This is useful for transient network errors or the API server being overloaded temporarily.""" return self._read_file_remotely(file_handle, format, filetype, strategy, remote_file) def _read_file_remotely( @@ -293,9 +286,8 @@ def _read_file_remotely( if response.status_code == 422: # 422 means the file couldn't be processed, but the API is working. Treat this as a parsing error (passing an error record to the destination). raise self._create_parse_error(remote_file, response.json()) - else: - # Other error statuses are raised as requests exceptions (retry everything except user errors) - response.raise_for_status() + # Other error statuses are raised as requests exceptions (retry everything except user errors) + response.raise_for_status() json_response = response.json() @@ -341,9 +333,8 @@ def _create_parse_error(self, remote_file: RemoteFile, message: str) -> RecordPa FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri, message=message ) - def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileType]: - """ - Detect the file type based on the file name and the file content. + def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> FileType | None: + """Detect the file type based on the file name and the file content. 
There are three strategies to determine the file type: 1. Use the mime type if available (only some sources support it) @@ -363,7 +354,7 @@ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileT file_type = detect_filetype( filename=remote_file.uri, ) - if file_type is not None and not file_type == FileType.UNK: + if file_type is not None and file_type != FileType.UNK: return file_type type_based_on_content = detect_filetype(file=file) @@ -373,26 +364,25 @@ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileT return type_based_on_content - def _supported_file_types(self) -> List[Any]: + def _supported_file_types(self) -> list[Any]: return [FileType.MD, FileType.PDF, FileType.DOCX, FileType.PPTX, FileType.TXT] def _get_file_type_error_message(self, file_type: FileType) -> str: supported_file_types = ", ".join([str(type) for type in self._supported_file_types()]) return f"File type {file_type} is not supported. Supported file types are {supported_file_types}" - def _render_markdown(self, elements: List[Any]) -> str: - return "\n\n".join((self._convert_to_markdown(el) for el in elements)) + def _render_markdown(self, elements: list[Any]) -> str: + return "\n\n".join(self._convert_to_markdown(el) for el in elements) - def _convert_to_markdown(self, el: Dict[str, Any]) -> str: + def _convert_to_markdown(self, el: dict[str, Any]) -> str: if dpath.get(el, "type") == "Title": heading_str = "#" * (dpath.get(el, "metadata/category_depth", default=1) or 1) return f"{heading_str} {dpath.get(el, 'text')}" - elif dpath.get(el, "type") == "ListItem": + if dpath.get(el, "type") == "ListItem": return f"- {dpath.get(el, 'text')}" - elif dpath.get(el, "type") == "Formula": + if dpath.get(el, "type") == "Formula": return f"```\n{dpath.get(el, 'text')}\n```" - else: - return str(dpath.get(el, "text", default="")) + return str(dpath.get(el, "text", default="")) @property def file_read_mode(self) -> FileReadMode: diff --git a/airbyte_cdk/sources/file_based/remote_file.py b/airbyte_cdk/sources/file_based/remote_file.py index 0197a35f..9bc6ea02 100644 --- a/airbyte_cdk/sources/file_based/remote_file.py +++ b/airbyte_cdk/sources/file_based/remote_file.py @@ -1,18 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from datetime import datetime -from typing import Optional from pydantic.v1 import BaseModel class RemoteFile(BaseModel): - """ - A file in a file-based stream. - """ + """A file in a file-based stream.""" uri: str last_modified: datetime - mime_type: Optional[str] = None + mime_type: str | None = None diff --git a/airbyte_cdk/sources/file_based/schema_helpers.py b/airbyte_cdk/sources/file_based/schema_helpers.py index 1b653db6..442a52d7 100644 --- a/airbyte_cdk/sources/file_based/schema_helpers.py +++ b/airbyte_cdk/sources/file_based/schema_helpers.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json +from collections.abc import Mapping from copy import deepcopy from enum import Enum from functools import total_ordering -from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Type, Union +from typing import Any, Literal, Union from airbyte_cdk.sources.file_based.exceptions import ( ConfigValidationError, @@ -14,7 +16,8 @@ SchemaInferenceError, ) -JsonSchemaSupportedType = Union[List[str], Literal["string"], str] + +JsonSchemaSupportedType = Union[list[str], Literal["string"], str] SchemaType = Mapping[str, Mapping[str, JsonSchemaSupportedType]] schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}} @@ -36,11 +39,10 @@ class ComparableType(Enum): def __lt__(self, other: Any) -> bool: if self.__class__ is other.__class__: return self.value < other.value # type: ignore - else: - return NotImplemented + return NotImplemented -TYPE_PYTHON_MAPPING: Mapping[str, Tuple[str, Optional[Type[Any]]]] = { +TYPE_PYTHON_MAPPING: Mapping[str, tuple[str, type[Any] | None]] = { "null": ("null", None), "array": ("array", list), "boolean": ("boolean", bool), @@ -53,7 +55,7 @@ def __lt__(self, other: Any) -> bool: PYTHON_TYPE_MAPPING = {t: k for k, (_, t) in TYPE_PYTHON_MAPPING.items()} -def get_comparable_type(value: Any) -> Optional[ComparableType]: +def get_comparable_type(value: Any) -> ComparableType | None: if value == "null": return ComparableType.NULL if value == "boolean": @@ -66,11 +68,10 @@ def get_comparable_type(value: Any) -> Optional[ComparableType]: return ComparableType.STRING if value == "object": return ComparableType.OBJECT - else: - return None + return None -def get_inferred_type(value: Any) -> Optional[ComparableType]: +def get_inferred_type(value: Any) -> ComparableType | None: if value is None: return ComparableType.NULL if isinstance(value, bool): @@ -83,13 +84,11 @@ def get_inferred_type(value: Any) -> Optional[ComparableType]: return ComparableType.STRING if isinstance(value, dict): return ComparableType.OBJECT - else: - return None + return None def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType: - """ - Returns a new dictionary that contains schema1 and schema2. + """Returns a new dictionary that contains schema1 and schema2. Schemas are merged as follows - If a key is in one schema but not the other, add it to the base schema with its existing type. 
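A toy version of the merge behaviour sketched in that docstring (not the CDK's merge_schemas, which works on {"type": ...} mappings and raises on incompatible types; here types are flat strings and the ranking is an assumption):

RANK = {"null": 0, "boolean": 1, "integer": 2, "number": 3, "string": 4}

def merge_toy(schema1, schema2):
    # Keys present in only one schema are kept; conflicting types widen upward.
    merged = dict(schema1)
    for key, t2 in schema2.items():
        t1 = merged.get(key)
        merged[key] = t2 if t1 is None else max(t1, t2, key=RANK.__getitem__)
    return merged

print(merge_toy({"id": "integer", "name": "string"}, {"id": "number", "active": "boolean"}))
# {'id': 'number', 'name': 'string', 'active': 'boolean'}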
@@ -107,7 +106,7 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType: if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]): raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t) - merged_schema: Dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping + merged_schema: dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping for k2, t2 in schema2.items(): t1 = merged_schema.get(k2) if t1 is None: @@ -136,7 +135,7 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) - detected_types=f"{t1},{t2}", ) # Schemas can still be merged if a key contains a null value in either t1 or t2, but it is still an object - elif ( + if ( (t1_type == "object" or t2_type == "object") and t1_type != "null" and t2_type != "null" @@ -148,21 +147,20 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) - key=key, detected_types=f"{t1},{t2}", ) - else: - comparable_t1 = get_comparable_type( - TYPE_PYTHON_MAPPING[t1_type][0] - ) # accessing the type_mapping value - comparable_t2 = get_comparable_type( - TYPE_PYTHON_MAPPING[t2_type][0] - ) # accessing the type_mapping value - if not comparable_t1 and comparable_t2: - raise SchemaInferenceError( - FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}" - ) - return max( - [t1, t2], - key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])), - ) # accessing the type_mapping value + comparable_t1 = get_comparable_type( + TYPE_PYTHON_MAPPING[t1_type][0] + ) # accessing the type_mapping value + comparable_t2 = get_comparable_type( + TYPE_PYTHON_MAPPING[t2_type][0] + ) # accessing the type_mapping value + if not comparable_t1 and comparable_t2: + raise SchemaInferenceError( + FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}" + ) + return max( + [t1, t2], + key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])), + ) # accessing the type_mapping value def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool: @@ -181,8 +179,7 @@ def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool: def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool: - """ - Return true iff the record conforms to the supplied schema. + """Return true iff the record conforms to the supplied schema. The record conforms to the supplied schema iff: - All columns in the record are in the schema. 
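In the same spirit, a toy conformance check for the rule quoted in that docstring (illustrative only; the real conforms_to_schema also validates nested array and object values):

PYTHON_TYPES = {"string": str, "integer": int, "number": (int, float), "boolean": bool}

def conforms_toy(record, schema):
    properties = schema.get("properties", {})
    # Every column in the record must be declared in the schema ...
    if any(column not in properties for column in record):
        return False
    # ... and non-null values must match the declared type.
    return all(
        value is None
        or isinstance(value, PYTHON_TYPES.get(properties[column].get("type"), object))
        for column, value in record.items()
    )

schema = {"properties": {"id": {"type": "integer"}, "name": {"type": "string"}}}
print(conforms_toy({"id": 1, "name": "a"}, schema))    # True
print(conforms_toy({"id": "1", "name": "a"}, schema))  # False: id is not an integer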
@@ -202,9 +199,9 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> if value is not None: if isinstance(expected_type, list): return any(is_equal_or_narrower_type(value, e) for e in expected_type) - elif expected_type == "object": + if expected_type == "object": return isinstance(value, dict) - elif expected_type == "array": + if expected_type == "array": if not isinstance(value, list): return False array_type = definition.get("items", {}).get("type") @@ -216,7 +213,7 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> return True -def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[Mapping[str, str]]: +def _parse_json_input(input_schema: str | Mapping[str, str]) -> Mapping[str, str] | None: try: if isinstance(input_schema, str): schema: Mapping[str, str] = json.loads(input_schema) @@ -235,10 +232,9 @@ def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[M def type_mapping_to_jsonschema( - input_schema: Optional[Union[str, Mapping[str, str]]], -) -> Optional[Mapping[str, Any]]: - """ - Return the user input schema (type mapping), transformed to JSON Schema format. + input_schema: str | Mapping[str, str] | None, +) -> Mapping[str, Any] | None: + """Return the user input schema (type mapping), transformed to JSON Schema format. Verify that the input schema: - is a key:value map diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py index 139511a9..dcfde911 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any class AbstractSchemaValidationPolicy(ABC): @@ -12,9 +14,7 @@ class AbstractSchemaValidationPolicy(ABC): @abstractmethod def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: - """ - Return True if the record passes the user's validation policy. - """ - raise NotImplementedError() + """Return True if the record passes the user's validation policy.""" + raise NotImplementedError diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py index 261b0fab..bd1e3a7f 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import ValidationPolicy from airbyte_cdk.sources.file_based.exceptions import ( @@ -17,7 +19,7 @@ class EmitRecordPolicy(AbstractSchemaValidationPolicy): name = "emit_record" def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: return True @@ -26,7 +28,7 @@ class SkipRecordPolicy(AbstractSchemaValidationPolicy): name = "skip_record" def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: return schema is not None and conforms_to_schema(record, schema) @@ -36,7 +38,7 @@ class WaitForDiscoverPolicy(AbstractSchemaValidationPolicy): validate_schema_before_sync = True def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: if schema is None or not conforms_to_schema(record, schema): raise StopSyncPerValidationPolicy( diff --git a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py index 8c0e1ebf..bf492c90 100644 --- a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py +++ b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py @@ -1,10 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod -from functools import cache, cached_property, lru_cache -from typing import Any, Dict, Iterable, List, Mapping, Optional, Type +from collections.abc import Iterable, Mapping +from functools import cache, cached_property +from typing import Any + +from deprecated import deprecated from airbyte_cdk import AirbyteMessage from airbyte_cdk.models import SyncMode @@ -30,12 +34,10 @@ from airbyte_cdk.sources.file_based.types import StreamSlice from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.checkpoint import Cursor -from deprecated import deprecated class AbstractFileBasedStream(Stream): - """ - A file-based stream in an Airbyte source. + """A file-based stream in an Airbyte source. In addition to the base Stream attributes, a file-based stream has - A config object (derived from the corresponding stream section in source config). @@ -52,11 +54,11 @@ class AbstractFileBasedStream(Stream): def __init__( self, config: FileBasedStreamConfig, - catalog_schema: Optional[Mapping[str, Any]], + catalog_schema: Mapping[str, Any] | None, stream_reader: AbstractFileBasedStreamReader, availability_strategy: AbstractFileBasedAvailabilityStrategy, discovery_policy: AbstractDiscoveryPolicy, - parsers: Dict[Type[Any], FileTypeParser], + parsers: dict[type[Any], FileTypeParser], validation_policy: AbstractSchemaValidationPolicy, errors_collector: FileBasedErrorsCollector, cursor: AbstractFileBasedCursor, @@ -77,9 +79,8 @@ def __init__( def primary_key(self) -> PrimaryKeyType: ... @cache - def list_files(self) -> List[RemoteFile]: - """ - List all files that belong to the stream. + def list_files(self) -> list[RemoteFile]: + """List all files that belong to the stream. 
The output of this method is cached so we don't need to list the files more than once. This means we won't pick up changes to the files during a sync. This method uses the @@ -89,20 +90,17 @@ def list_files(self) -> List[RemoteFile]: @abstractmethod def get_files(self) -> Iterable[RemoteFile]: - """ - List all files that belong to the stream as defined by the stream's globs. - """ + """List all files that belong to the stream as defined by the stream's globs.""" ... def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[StreamSlice] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: StreamSlice | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any] | AirbyteMessage]: - """ - Yield all records from all remote files in `list_files_for_this_sync`. + """Yield all records from all remote files in `list_files_for_this_sync`. This method acts as an adapter between the generic Stream interface and the file-based's stream since file-based streams manage their own states. """ @@ -114,45 +112,37 @@ def read_records( def read_records_from_slice( self, stream_slice: StreamSlice ) -> Iterable[Mapping[str, Any] | AirbyteMessage]: - """ - Yield all records from all remote files in `list_files_for_this_sync`. - """ + """Yield all records from all remote files in `list_files_for_this_sync`.""" ... def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: - """ - This method acts as an adapter between the generic Stream interface and the file-based's + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: + """This method acts as an adapter between the generic Stream interface and the file-based's stream since file-based streams manage their own states. """ return self.compute_slices() @abstractmethod - def compute_slices(self) -> Iterable[Optional[StreamSlice]]: - """ - Return a list of slices that will be used to read files in the current sync. + def compute_slices(self) -> Iterable[StreamSlice | None]: + """Return a list of slices that will be used to read files in the current sync. :return: The slices to use for the current sync. """ ... @abstractmethod - @lru_cache(maxsize=None) + @cache def get_json_schema(self) -> Mapping[str, Any]: - """ - Return the JSON Schema for a stream. - """ + """Return the JSON Schema for a stream.""" ... @abstractmethod - def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: - """ - Infer the schema for files in the stream. - """ + def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: + """Infer the schema for files in the stream.""" ... 
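One behaviour-preserving swap in the hunks above is @lru_cache(maxsize=None) becoming @cache: functools.cache (Python 3.9+) is defined as lru_cache(maxsize=None), so the memoization behaviour of get_json_schema is unchanged. A tiny demonstration:

from functools import cache

@cache
def get_json_schema_sketch():
    # With @cache this body runs only once per process, exactly like the lru_cache form.
    print("computing schema")
    return {"type": "object", "properties": {}}

get_json_schema_sketch()
get_json_schema_sketch()  # cache hit: "computing schema" is printed only once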
def get_parser(self) -> FileTypeParser: @@ -170,12 +160,11 @@ def record_passes_validation_policy(self, record: Mapping[str, Any]) -> bool: return self.validation_policy.record_passes_validation_policy( record=record, schema=self.catalog_schema ) - else: - raise RecordParseError( - FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, - stream=self.name, - validation_policy=self.config.validation_policy, - ) + raise RecordParseError( + FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, + stream=self.name, + validation_policy=self.config.validation_policy, + ) @cached_property @deprecated(version="3.7.0") @@ -186,9 +175,8 @@ def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy: def name(self) -> str: return self.config.name - def get_cursor(self) -> Optional[Cursor]: - """ - This is a temporary hack. Because file-based, declarative, and concurrent have _slightly_ different cursor implementations + def get_cursor(self) -> Cursor | None: + """This is a temporary hack. Because file-based, declarative, and concurrent have _slightly_ different cursor implementations the file-based cursor isn't compatible with the cursor-based iteration flow in core.py top-level CDK. By setting this to None, we defer to the regular incremental checkpoint flow. Once all cursors are consolidated under a common interface then this override can be removed. diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py index fda609ae..37b68c40 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py @@ -1,11 +1,15 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import copy import logging -from functools import cache, lru_cache -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from collections.abc import Iterable, Mapping, MutableMapping +from functools import cache +from typing import TYPE_CHECKING, Any + +from deprecated.classic import deprecated from airbyte_cdk.models import ( AirbyteLogMessage, @@ -43,7 +47,7 @@ from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.utils.schema_helpers import InternalConfig from airbyte_cdk.sources.utils.slice_logger import SliceLogger -from deprecated.classic import deprecated + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.cursor import ( @@ -63,12 +67,10 @@ def create_from_stream( stream: AbstractFileBasedStream, source: AbstractSource, logger: logging.Logger, - state: Optional[MutableMapping[str, Any]], - cursor: "AbstractConcurrentFileBasedCursor", - ) -> "FileBasedStreamFacade": - """ - Create a ConcurrentStream from a FileBasedStream object. 
- """ + state: MutableMapping[str, Any] | None, + cursor: AbstractConcurrentFileBasedCursor, + ) -> FileBasedStreamFacade: + """Create a ConcurrentStream from a FileBasedStream object.""" pk = get_primary_key_from_stream(stream.primary_key) cursor_field = get_cursor_field_from_stream(stream) stream._cursor = cursor @@ -114,9 +116,7 @@ def __init__( slice_logger: SliceLogger, logger: logging.Logger, ): - """ - :param stream: The underlying AbstractStream - """ + """:param stream: The underlying AbstractStream""" self._abstract_stream = stream self._legacy_stream = legacy_stream self._cursor = cursor @@ -127,11 +127,10 @@ def __init__( self.validation_policy = legacy_stream.validation_policy @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: if self._abstract_stream.cursor_field is None: return [] - else: - return self._abstract_stream.cursor_field + return self._abstract_stream.cursor_field @property def name(self) -> str: @@ -146,7 +145,7 @@ def supports_incremental(self) -> bool: def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy: return self._legacy_stream.availability_strategy - @lru_cache(maxsize=None) + @cache def get_json_schema(self) -> Mapping[str, Any]: return self._abstract_stream.get_json_schema() @@ -166,10 +165,10 @@ def get_files(self) -> Iterable[RemoteFile]: def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any]]: yield from self._legacy_stream.read_records_from_slice(stream_slice) # type: ignore[misc] # Only Mapping[str, Any] is expected for legacy streams, not AirbyteMessage - def compute_slices(self) -> Iterable[Optional[StreamSlice]]: + def compute_slices(self) -> Iterable[StreamSlice | None]: return self._legacy_stream.compute_slices() - def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: return self._legacy_stream.infer_schema(files) def get_underlying_stream(self) -> DefaultStream: @@ -189,9 +188,9 @@ def read( def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: try: yield from self._read_records() @@ -221,12 +220,12 @@ class FileBasedStreamPartition(Partition): def __init__( self, stream: AbstractFileBasedStream, - _slice: Optional[Mapping[str, Any]], + _slice: Mapping[str, Any] | None, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], - cursor: "AbstractConcurrentFileBasedCursor", + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, + cursor: AbstractConcurrentFileBasedCursor, ): self._stream = stream self._slice = _slice @@ -280,7 +279,7 @@ def read(self) -> Iterable[Record]: else: raise e - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: if self._slice is None: return None assert ( @@ -303,11 +302,9 @@ def __hash__(self) -> int: raise ValueError( f"Slices for file-based streams should be of length 1, but got {len(self._slice['files'])}. This is unexpected. Please contact Support." 
) - else: - s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}" + s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}" return hash((self._stream.name, s)) - else: - return hash(self._stream.name) + return hash(self._stream.name) def stream_name(self) -> str: return self._stream.name @@ -326,9 +323,9 @@ def __init__( stream: AbstractFileBasedStream, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], - cursor: "AbstractConcurrentFileBasedCursor", + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, + cursor: AbstractConcurrentFileBasedCursor, ): self._stream = stream self._message_repository = message_repository diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py index ef8b290d..9b5ba76e 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, MutableMapping from datetime import datetime -from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping +from typing import TYPE_CHECKING, Any from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor @@ -14,6 +16,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition @@ -33,7 +36,7 @@ def observe(self, record: Record) -> None: ... def close_partition(self, partition: Partition) -> None: ... @abstractmethod - def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: ... + def set_pending_partitions(self, partitions: list[FileBasedStreamPartition]) -> None: ... @abstractmethod def add_file(self, file: RemoteFile) -> None: ... diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py index e7bb2796..8356a348 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
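A second pattern repeated across these hunks (`record_passes_validation_policy`, `cursor_field`, `__hash__`, and others) is dropping an `else:` whose only job is to wrap the fall-through branch after an early `return` or `raise`. A small sketch of the rewrite, using a hypothetical `stream` argument rather than the real classes:

    from __future__ import annotations


    def cursor_field_before(stream) -> str | list[str]:
        if stream.cursor_field is None:
            return []
        else:  # redundant: the branch above already returned
            return stream.cursor_field


    def cursor_field_after(stream) -> str | list[str]:
        if stream.cursor_field is None:
            return []
        return stream.cursor_field

Behavior is unchanged; the flattened version simply removes one indentation level from the common path.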
# +from __future__ import annotations import logging +from collections.abc import Iterable, MutableMapping from datetime import datetime, timedelta from threading import RLock -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import TYPE_CHECKING, Any from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager @@ -21,6 +23,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition @@ -41,7 +44,7 @@ def __init__( self, stream_config: FileBasedStreamConfig, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, stream_state: MutableMapping[str, Any], message_repository: MessageRepository, connector_state_manager: ConnectorStateManager, @@ -60,7 +63,7 @@ def __init__( ) self._state_lock = RLock() self._pending_files_lock = RLock() - self._pending_files: Optional[Dict[str, RemoteFile]] = None + self._pending_files: dict[str, RemoteFile] | None = None self._file_to_datetime_history = stream_state.get("history", {}) if stream_state else {} self._prev_cursor_value = self._compute_prev_sync_cursor(stream_state) self._sync_start = self._compute_start_time() @@ -79,7 +82,7 @@ def close_partition(self, partition: Partition) -> None: "Expected pending partitions to be set but it was not. This is unexpected. Please contact Support." ) - def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: + def set_pending_partitions(self, partitions: list[FileBasedStreamPartition]) -> None: with self._pending_files_lock: self._pending_files = {} for partition in partitions: @@ -93,7 +96,7 @@ def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) - ) self._pending_files.update({file.uri: file}) - def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datetime, str]: + def _compute_prev_sync_cursor(self, value: StreamState | None) -> tuple[datetime, str]: if not value: return self.zero_value, "" prev_cursor_str = value.get(self._cursor_field.cursor_field_key) or self.zero_cursor_value @@ -112,12 +115,12 @@ def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datet cursor_dt, cursor_uri = cursor_str.split("_", 1) return datetime.strptime(cursor_dt, self.DATE_TIME_FORMAT), cursor_uri - def _get_cursor_key_from_file(self, file: Optional[RemoteFile]) -> str: + def _get_cursor_key_from_file(self, file: RemoteFile | None) -> str: if file: return f"{datetime.strftime(file.last_modified, self.DATE_TIME_FORMAT)}_{file.uri}" return self.zero_cursor_value - def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: + def _compute_earliest_file_in_history(self) -> RemoteFile | None: with self._state_lock: if self._file_to_datetime_history: filename, last_modified = min( @@ -127,12 +130,10 @@ def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT), ) - else: - return None + return None def add_file(self, file: RemoteFile) -> None: - """ - Add a file to the cursor. This method is called when a file is processed by the stream. + """Add a file to the cursor. This method is called when a file is processed by the stream. 
:param file: The file to add """ if self._pending_files is None: @@ -189,20 +190,18 @@ def _get_new_cursor_value(self) -> str: # To avoid missing files, we only increment the cursor up to the oldest pending file, because we know # that all older files have been synced. return self._get_cursor_key_from_file(self._compute_earliest_pending_file()) - elif self._file_to_datetime_history: + if self._file_to_datetime_history: # If all partitions have been synced, we know that the sync is up-to-date and so can advance # the cursor to the newest file in history. return self._get_cursor_key_from_file(self._compute_latest_file_in_history()) - else: - return f"{self.zero_value.strftime(self.DATE_TIME_FORMAT)}_" + return f"{self.zero_value.strftime(self.DATE_TIME_FORMAT)}_" - def _compute_earliest_pending_file(self) -> Optional[RemoteFile]: + def _compute_earliest_pending_file(self) -> RemoteFile | None: if self._pending_files: return min(self._pending_files.values(), key=lambda x: x.last_modified) - else: - return None + return None - def _compute_latest_file_in_history(self) -> Optional[RemoteFile]: + def _compute_latest_file_in_history(self) -> RemoteFile | None: with self._state_lock: if self._file_to_datetime_history: filename, last_modified = max( @@ -212,14 +211,12 @@ def _compute_latest_file_in_history(self) -> Optional[RemoteFile]: uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT), ) - else: - return None + return None def get_files_to_sync( self, all_files: Iterable[RemoteFile], logger: logging.Logger ) -> Iterable[RemoteFile]: - """ - Given the list of files in the source, return the files that should be synced. + """Given the list of files in the source, return the files that should be synced. :param all_files: All files in the source :param logger: :return: The files that should be synced @@ -253,28 +250,23 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: ) ) return False - else: - return file.last_modified > updated_at_from_history + return file.last_modified > updated_at_from_history prev_cursor_timestamp, prev_cursor_uri = self._prev_cursor_value if self._is_history_full(): if file.last_modified > prev_cursor_timestamp: # If the history is partial and the file's datetime is strictly greater than the cursor, we should sync it return True - elif file.last_modified == prev_cursor_timestamp: + if file.last_modified == prev_cursor_timestamp: # If the history is partial and the file's datetime is equal to the earliest file in the history, # we should sync it if its uri is greater than or equal to the cursor value. return file.uri > prev_cursor_uri - else: - return file.last_modified >= self._sync_start - else: - # The file is not in the history and the history is complete. We know we need to sync the file - return True + return file.last_modified >= self._sync_start + # The file is not in the history and the history is complete. We know we need to sync the file + return True def _is_history_full(self) -> bool: - """ - Returns true if the state's history is full, meaning new entries will start to replace old entries. 
- """ + """Returns true if the state's history is full, meaning new entries will start to replace old entries.""" with self._state_lock: if self._file_to_datetime_history is None: raise RuntimeError( @@ -285,21 +277,18 @@ def _is_history_full(self) -> bool: def _compute_start_time(self) -> datetime: if not self._file_to_datetime_history: return datetime.min - else: - earliest = min(self._file_to_datetime_history.values()) - earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) - if self._is_history_full(): - time_window = datetime.now() - self._time_window_if_history_is_full - earliest_dt = min(earliest_dt, time_window) - return earliest_dt + earliest = min(self._file_to_datetime_history.values()) + earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) + if self._is_history_full(): + time_window = datetime.now() - self._time_window_if_history_is_full + earliest_dt = min(earliest_dt, time_window) + return earliest_dt def get_start_time(self) -> datetime: return self._sync_start def get_state(self) -> MutableMapping[str, Any]: - """ - Get the state of the cursor. - """ + """Get the state of the cursor.""" with self._state_lock: return { "history": self._file_to_datetime_history, diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py index b8926451..8a750b52 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Iterable, MutableMapping from datetime import datetime -from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping, Optional +from typing import TYPE_CHECKING, Any from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig @@ -18,6 +20,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition @@ -29,7 +32,7 @@ def __init__( self, stream_config: FileBasedStreamConfig, message_repository: MessageRepository, - stream_namespace: Optional[str], + stream_namespace: str | None, **kwargs: Any, ): self._stream_name = stream_config.name @@ -50,7 +53,7 @@ def observe(self, record: Record) -> None: def close_partition(self, partition: Partition) -> None: pass - def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: + def set_pending_partitions(self, partitions: list[FileBasedStreamPartition]) -> None: pass def add_file(self, file: RemoteFile) -> None: diff --git a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py index 4a5eadb4..d8a46002 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, MutableMapping from datetime import datetime -from typing import Any, Iterable, MutableMapping +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile @@ -13,52 +15,41 @@ class AbstractFileBasedCursor(ABC): - """ - Abstract base class for cursors used by file-based streams. - """ + """Abstract base class for cursors used by file-based streams.""" @abstractmethod def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any): - """ - Common interface for all cursors. - """ + """Common interface for all cursors.""" ... @abstractmethod def add_file(self, file: RemoteFile) -> None: - """ - Add a file to the cursor. This method is called when a file is processed by the stream. + """Add a file to the cursor. This method is called when a file is processed by the stream. :param file: The file to add """ ... @abstractmethod def set_initial_state(self, value: StreamState) -> None: - """ - Set the initial state of the cursor. The cursor cannot be initialized at construction time because the stream doesn't know its state yet. + """Set the initial state of the cursor. The cursor cannot be initialized at construction time because the stream doesn't know its state yet. :param value: The stream state """ @abstractmethod def get_state(self) -> MutableMapping[str, Any]: - """ - Get the state of the cursor. - """ + """Get the state of the cursor.""" ... @abstractmethod def get_start_time(self) -> datetime: - """ - Returns the start time of the current sync. - """ + """Returns the start time of the current sync.""" ... @abstractmethod def get_files_to_sync( self, all_files: Iterable[RemoteFile], logger: logging.Logger ) -> Iterable[RemoteFile]: - """ - Given the list of files in the source, return the files that should be synced. + """Given the list of files in the source, return the files that should be synced. :param all_files: All files in the source :param logger: :return: The files that should be synced diff --git a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py index 814bc1a1..e0802ee8 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
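To make the `AbstractFileBasedCursor` interface above concrete, here is a minimal, purely illustrative cursor that keeps no history and re-syncs every file on every run. It implements only the abstract methods listed in this hunk; the import paths are taken from elsewhere in this patch, and `StreamState` is assumed to come from `airbyte_cdk.sources.file_based.types` as defined later in the patch:

    from __future__ import annotations

    import logging
    from collections.abc import Iterable, MutableMapping
    from datetime import datetime
    from typing import Any

    from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
    from airbyte_cdk.sources.file_based.remote_file import RemoteFile
    from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
    from airbyte_cdk.sources.file_based.types import StreamState


    class SyncEverythingCursor(AbstractFileBasedCursor):
        """Trivial cursor sketch: no history, every file is synced on every run."""

        def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any):
            self._start = datetime.min

        def add_file(self, file: RemoteFile) -> None:
            pass  # nothing to remember

        def set_initial_state(self, value: StreamState) -> None:
            pass  # stateless by design

        def get_state(self) -> MutableMapping[str, Any]:
            return {}

        def get_start_time(self) -> datetime:
            return self._start

        def get_files_to_sync(
            self, all_files: Iterable[RemoteFile], logger: logging.Logger
        ) -> Iterable[RemoteFile]:
            yield from all_files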
# +from __future__ import annotations import logging +from collections.abc import Iterable, MutableMapping from datetime import datetime, timedelta -from typing import Any, Iterable, MutableMapping, Optional +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile @@ -34,7 +36,7 @@ def __init__(self, stream_config: FileBasedStreamConfig, **_: Any): ) self._start_time = self._compute_start_time() - self._initial_earliest_file_in_history: Optional[RemoteFile] = None + self._initial_earliest_file_in_history: RemoteFile | None = None def set_initial_state(self, value: StreamState) -> None: self._file_to_datetime_history = value.get("history", {}) @@ -59,9 +61,8 @@ def get_state(self) -> StreamState: state = {"history": self._file_to_datetime_history, self.CURSOR_FIELD: self._get_cursor()} return state - def _get_cursor(self) -> Optional[str]: - """ - Returns the cursor value. + def _get_cursor(self) -> str | None: + """Returns the cursor value. Files are synced in order of last-modified with secondary sort on filename, so the cursor value is a string joining the last-modified timestamp of the last synced file and the name of the file. @@ -74,9 +75,7 @@ def _get_cursor(self) -> Optional[str]: return None def _is_history_full(self) -> bool: - """ - Returns true if the state's history is full, meaning new entries will start to replace old entries. - """ + """Returns true if the state's history is full, meaning new entries will start to replace old entries.""" return len(self._file_to_datetime_history) >= self.DEFAULT_MAX_HISTORY_SIZE def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: @@ -99,16 +98,14 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # If the history is partial and the file's datetime is strictly greater than the earliest file in the history, # we should sync it return True - elif file.last_modified == self._initial_earliest_file_in_history.last_modified: + if file.last_modified == self._initial_earliest_file_in_history.last_modified: # If the history is partial and the file's datetime is equal to the earliest file in the history, # we should sync it if its uri is strictly greater than the earliest file in the history return file.uri > self._initial_earliest_file_in_history.uri - else: - # Otherwise, only sync the file if it has been modified since the start of the time window - return file.last_modified >= self.get_start_time() - else: - # The file is not in the history and the history is complete. We know we need to sync the file - return True + # Otherwise, only sync the file if it has been modified since the start of the time window + return file.last_modified >= self.get_start_time() + # The file is not in the history and the history is complete. 
We know we need to sync the file + return True def get_files_to_sync( self, all_files: Iterable[RemoteFile], logger: logging.Logger @@ -126,7 +123,7 @@ def get_files_to_sync( def get_start_time(self) -> datetime: return self._start_time - def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: + def _compute_earliest_file_in_history(self) -> RemoteFile | None: if self._file_to_datetime_history: filename, last_modified = min( self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) @@ -134,16 +131,14 @@ def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: return RemoteFile( uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT) ) - else: - return None + return None def _compute_start_time(self) -> datetime: if not self._file_to_datetime_history: return datetime.min - else: - earliest = min(self._file_to_datetime_history.values()) - earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) - if self._is_history_full(): - time_window = datetime.now() - self._time_window_if_history_is_full - earliest_dt = min(earliest_dt, time_window) - return earliest_dt + earliest = min(self._file_to_datetime_history.values()) + earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) + if self._is_history_full(): + time_window = datetime.now() - self._time_window_if_history_is_full + earliest_dt = min(earliest_dt, time_window) + return earliest_dt diff --git a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py index a5cae2e6..969d6958 100644 --- a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +++ b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py @@ -1,13 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import asyncio import itertools import traceback +from collections.abc import Iterable, Mapping, MutableMapping from copy import deepcopy from functools import cache -from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Union +from typing import Any from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, FailureType, Level from airbyte_cdk.models import Type as MessageType @@ -38,9 +40,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin): - """ - The default file-based stream. 
- """ + """The default file-based stream.""" FILE_TRANSFER_KW = "use_file_transfer" DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" @@ -66,7 +66,7 @@ def state(self, value: MutableMapping[str, Any]) -> None: self._cursor.set_initial_state(value) @property # type: ignore # mypy complains wrong type, but AbstractFileBasedCursor is parent of file-based cursors - def cursor(self) -> Optional[AbstractFileBasedCursor]: + def cursor(self) -> AbstractFileBasedCursor | None: return self._cursor @cursor.setter @@ -84,8 +84,8 @@ def primary_key(self) -> PrimaryKeyType: ) def _filter_schema_invalid_properties( - self, configured_catalog_json_schema: Dict[str, Any] - ) -> Dict[str, Any]: + self, configured_catalog_json_schema: dict[str, Any] + ) -> dict[str, Any]: if self.use_file_transfer: return { "type": "object", @@ -95,10 +95,9 @@ def _filter_schema_invalid_properties( self.ab_file_name_col: {"type": "string"}, }, } - else: - return super()._filter_schema_invalid_properties(configured_catalog_json_schema) + return super()._filter_schema_invalid_properties(configured_catalog_json_schema) - def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: + def compute_slices(self) -> Iterable[Mapping[str, Any] | None]: # Sort files by last_modified, uri and return them grouped by last_modified all_files = self.list_files() files_to_read = self._cursor.get_files_to_sync(all_files, self.logger) @@ -126,8 +125,7 @@ def transform_record_for_file_transfer( return record def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[AirbyteMessage]: - """ - Yield all records from all remote files in `list_files_for_this_sync`. + """Yield all records from all remote files in `list_files_for_this_sync`. If an error is encountered reading records from a file, log a message and do not attempt to sync the rest of the file. @@ -146,7 +144,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte try: if self.use_file_transfer: self.logger.info(f"{self.name}: {file} file-based syncing") - # todo: complete here the code to not rely on local parser + # TODO: complete here the code to not rely on local parser file_transfer = FileTransfer() for record in file_transfer.get_file( self.config, file, self.stream_reader, self.logger @@ -222,9 +220,8 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte ) @property - def cursor_field(self) -> Union[str, List[str]]: - """ - Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. + def cursor_field(self) -> str | list[str]: + """Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor. 
""" return self.ab_last_mod_col @@ -256,22 +253,21 @@ def get_json_schema(self) -> JsonSchema: def _get_raw_json_schema(self) -> JsonSchema: if self.use_file_transfer: return file_transfer_schema - elif self.config.input_schema: + if self.config.input_schema: return self.config.get_input_schema() # type: ignore - elif self.config.schemaless: + if self.config.schemaless: return schemaless_schema - else: - files = self.list_files() - first_n_files = len(files) - - if self.config.recent_n_files_to_read_for_schema_discovery: - self.logger.info( - msg=( - f"Only first {self.config.recent_n_files_to_read_for_schema_discovery} files will be used to infer schema " - f"for stream {self.name} due to limitation in config." - ) + files = self.list_files() + first_n_files = len(files) + + if self.config.recent_n_files_to_read_for_schema_discovery: + self.logger.info( + msg=( + f"Only first {self.config.recent_n_files_to_read_for_schema_discovery} files will be used to infer schema " + f"for stream {self.name} due to limitation in config." ) - first_n_files = self.config.recent_n_files_to_read_for_schema_discovery + ) + first_n_files = self.config.recent_n_files_to_read_for_schema_discovery if first_n_files == 0: self.logger.warning( @@ -306,14 +302,12 @@ def _get_raw_json_schema(self) -> JsonSchema: return schema def get_files(self) -> Iterable[RemoteFile]: - """ - Return all files that belong to the stream as defined by the stream's globs. - """ + """Return all files that belong to the stream as defined by the stream's globs.""" return self.stream_reader.get_matching_files( self.config.globs or [], self.config.legacy_prefix, self.logger ) - def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: loop = asyncio.get_event_loop() schema = loop.run_until_complete(self._infer_schema(files)) # as infer schema returns a Mapping that is assumed to be immutable, we need to create a deepcopy to avoid modifying the reference @@ -336,15 +330,14 @@ def _fill_nulls(schema: Mapping[str, Any]) -> Mapping[str, Any]: DefaultFileBasedStream._fill_nulls(item) return schema - async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: - """ - Infer the schema for a stream. + async def _infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: + """Infer the schema for a stream. Each file type has a corresponding `infer_schema` handler. Dispatch on file type. """ base_schema: SchemaType = {} - pending_tasks: Set[asyncio.tasks.Task[SchemaType]] = set() + pending_tasks: set[asyncio.tasks.Task[SchemaType]] = set() n_started, n_files = 0, len(files) files_iterator = iter(files) diff --git a/airbyte_cdk/sources/file_based/types.py b/airbyte_cdk/sources/file_based/types.py index b83bf37a..11bb5808 100644 --- a/airbyte_cdk/sources/file_based/types.py +++ b/airbyte_cdk/sources/file_based/types.py @@ -4,7 +4,9 @@ from __future__ import annotations -from typing import Any, Mapping, MutableMapping +from collections.abc import Mapping, MutableMapping +from typing import Any + StreamSlice = Mapping[str, Any] StreamState = MutableMapping[str, Any] diff --git a/airbyte_cdk/sources/http_config.py b/airbyte_cdk/sources/http_config.py index 289ed9a9..3e994f7d 100644 --- a/airbyte_cdk/sources/http_config.py +++ b/airbyte_cdk/sources/http_config.py @@ -7,4 +7,7 @@ # order to fix that, we will increase the requests library pool_maxsize. 
As there are many pieces of code that sets a requests.Session, we # are creating this variable here so that a change in one affects the other. This can be removed once we merge how we do HTTP requests in # one piece of code or once we make connection pool size configurable for each piece of code +from __future__ import annotations + + MAX_CONNECTION_POOL_SIZE = 20 diff --git a/airbyte_cdk/sources/http_logger.py b/airbyte_cdk/sources/http_logger.py index cbdc3c68..0e246041 100644 --- a/airbyte_cdk/sources/http_logger.py +++ b/airbyte_cdk/sources/http_logger.py @@ -1,10 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - -from typing import Optional, Union +from __future__ import annotations import requests + from airbyte_cdk.sources.message import LogMessage @@ -12,7 +12,7 @@ def format_http_message( response: requests.Response, title: str, description: str, - stream_name: Optional[str], + stream_name: str | None, is_auxiliary: bool = None, ) -> LogMessage: request = response.request @@ -47,5 +47,5 @@ def format_http_message( return log_message -def _normalize_body_string(body_str: Optional[Union[str, bytes]]) -> Optional[str]: +def _normalize_body_string(body_str: str | bytes | None) -> str | None: return body_str.decode() if isinstance(body_str, (bytes, bytearray)) else body_str diff --git a/airbyte_cdk/sources/message/repository.py b/airbyte_cdk/sources/message/repository.py index 2fc156e8..d06a9094 100644 --- a/airbyte_cdk/sources/message/repository.py +++ b/airbyte_cdk/sources/message/repository.py @@ -1,17 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging from abc import ABC, abstractmethod from collections import deque -from typing import Callable, Deque, Iterable, List, Optional +from collections.abc import Callable, Iterable from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.utils.types import JsonType from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets + _LOGGER = logging.getLogger("MessageRepository") _SUPPORTED_MESSAGE_TYPES = {Type.CONTROL, Type.LOG} LogMessage = dict[str, JsonType] @@ -45,19 +47,18 @@ def _is_severe_enough(threshold: Level, level: Level) -> bool: class MessageRepository(ABC): @abstractmethod def emit_message(self, message: AirbyteMessage) -> None: - raise NotImplementedError() + raise NotImplementedError @abstractmethod def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: - """ - Computing messages can be resource consuming. This method is specialized for logging because we want to allow for lazy evaluation if + """Computing messages can be resource consuming. 
This method is specialized for logging because we want to allow for lazy evaluation if the log level is less severe than what is configured """ - raise NotImplementedError() + raise NotImplementedError @abstractmethod def consume_queue(self) -> Iterable[AirbyteMessage]: - raise NotImplementedError() + raise NotImplementedError class NoopMessageRepository(MessageRepository): @@ -73,7 +74,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: class InMemoryMessageRepository(MessageRepository): def __init__(self, log_level: Level = Level.INFO) -> None: - self._message_queue: Deque[AirbyteMessage] = deque() + self._message_queue: deque[AirbyteMessage] = deque() self._log_level = log_level def emit_message(self, message: AirbyteMessage) -> None: @@ -119,7 +120,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: return self._decorated.consume_queue() def _append_second_to_first( - self, first: LogMessage, second: LogMessage, path: Optional[List[str]] = None + self, first: LogMessage, second: LogMessage, path: list[str] | None = None ) -> LogMessage: if path is None: path = [] diff --git a/airbyte_cdk/sources/source.py b/airbyte_cdk/sources/source.py index 2958d82c..390af101 100644 --- a/airbyte_cdk/sources/source.py +++ b/airbyte_cdk/sources/source.py @@ -1,11 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Any, Generic, Iterable, List, Mapping, Optional, TypeVar +from collections.abc import Iterable, Mapping +from typing import Any, Generic, TypeVar from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig from airbyte_cdk.models import ( @@ -17,6 +18,7 @@ ConfiguredAirbyteCatalogSerializer, ) + TState = TypeVar("TState") TCatalog = TypeVar("TCatalog") @@ -38,30 +40,26 @@ def read( logger: logging.Logger, config: TConfig, catalog: TCatalog, - state: Optional[TState] = None, + state: TState | None = None, ) -> Iterable[AirbyteMessage]: - """ - Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, catalog, and state. - """ + """Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, catalog, and state.""" @abstractmethod def discover(self, logger: logging.Logger, config: TConfig) -> AirbyteCatalog: - """ - Returns an AirbyteCatalog representing the available streams and fields in this integration. For example, given valid credentials to a + """Returns an AirbyteCatalog representing the available streams and fields in this integration. For example, given valid credentials to a Postgres database, returns an Airbyte catalog where each postgres table is a stream, and each table column is a field. """ class Source( DefaultConnectorMixin, - BaseSource[Mapping[str, Any], List[AirbyteStateMessage], ConfiguredAirbyteCatalog], + BaseSource[Mapping[str, Any], list[AirbyteStateMessage], ConfiguredAirbyteCatalog], ABC, ): # can be overridden to change an input state. @classmethod - def read_state(cls, state_path: str) -> List[AirbyteStateMessage]: - """ - Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either + def read_state(cls, state_path: str) -> list[AirbyteStateMessage]: + """Retrieves the input state of a sync by reading from the specified JSON file. 
Incoming state can be deserialized into either a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the incoming input type, it will always be transformed and output as a list of AirbyteStateMessage(s). :param state_path: The filepath to where the stream states are located diff --git a/airbyte_cdk/sources/streams/availability_strategy.py b/airbyte_cdk/sources/streams/availability_strategy.py index 312ddae1..04284fcf 100644 --- a/airbyte_cdk/sources/streams/availability_strategy.py +++ b/airbyte_cdk/sources/streams/availability_strategy.py @@ -1,30 +1,30 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import typing from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.streams.core import Stream, StreamData + if typing.TYPE_CHECKING: from airbyte_cdk.sources import Source class AvailabilityStrategy(ABC): - """ - Abstract base class for checking stream availability. - """ + """Abstract base class for checking stream availability.""" @abstractmethod def check_availability( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: - """ - Checks stream availability. + self, stream: Stream, logger: logging.Logger, source: Source | None = None + ) -> tuple[bool, str | None]: + """Checks stream availability. :param stream: stream :param logger: source logger @@ -36,9 +36,8 @@ def check_availability( """ @staticmethod - def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]: - """ - Gets the first stream_slice from a given stream's stream_slices. + def get_first_stream_slice(stream: Stream) -> Mapping[str, Any] | None: + """Gets the first stream_slice from a given stream's stream_slices. :param stream: stream :raises StopIteration: if there is no first slice to return (the stream_slices generator is empty) :return: first stream slice from 'stream_slices' generator (`None` is a valid stream slice) @@ -55,10 +54,9 @@ def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]: @staticmethod def get_first_record_for_slice( - stream: Stream, stream_slice: Optional[Mapping[str, Any]] + stream: Stream, stream_slice: Mapping[str, Any] | None ) -> StreamData: - """ - Gets the first record for a stream_slice of a stream. + """Gets the first record for a stream_slice of a stream. :param stream: stream instance from which to read records :param stream_slice: stream_slice parameters for slicing the stream diff --git a/airbyte_cdk/sources/streams/call_rate.py b/airbyte_cdk/sources/streams/call_rate.py index 19ae603c..47fb013a 100644 --- a/airbyte_cdk/sources/streams/call_rate.py +++ b/airbyte_cdk/sources/streams/call_rate.py @@ -1,24 +1,26 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
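The `AvailabilityStrategy` interface above returns a `tuple[bool, str | None]` of availability plus an optional reason, and ships two static helpers for probing the first slice and record. A hedged sketch of a concrete strategy built on those helpers; the class name is hypothetical and the error handling is deliberately simplified:

    from __future__ import annotations

    import logging

    from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
    from airbyte_cdk.sources.streams.core import Stream


    class FirstRecordAvailabilityStrategy(AvailabilityStrategy):
        """Marks a stream available if one record can be read from its first slice."""

        def check_availability(
            self, stream: Stream, logger: logging.Logger, source=None
        ) -> tuple[bool, str | None]:
            try:
                first_slice = self.get_first_stream_slice(stream)
                self.get_first_record_for_slice(stream, first_slice)
            except StopIteration:
                # No slices or no records: treat the stream as reachable but empty.
                return True, None
            except Exception as exc:
                return False, f"Unable to read from stream {stream.name}: {exc!r}"
            return True, None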
# +from __future__ import annotations import abc import dataclasses import datetime import logging import time +from collections.abc import Mapping from datetime import timedelta from threading import RLock -from typing import TYPE_CHECKING, Any, Mapping, Optional +from typing import TYPE_CHECKING, Any from urllib import parse import requests import requests_cache -from pyrate_limiter import InMemoryBucket, Limiter +from pyrate_limiter import InMemoryBucket, Limiter, RateItem, TimeClock from pyrate_limiter import Rate as PyRateRate -from pyrate_limiter import RateItem, TimeClock from pyrate_limiter.exceptions import BucketFullException + # prevents mypy from complaining about missing session attributes in LimiterMixin if TYPE_CHECKING: MIXIN_BASE = requests.Session @@ -76,9 +78,7 @@ def try_acquire(self, request: Any, weight: int) -> None: """ @abc.abstractmethod - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Update call rate counting with current values :param available_calls: @@ -91,9 +91,7 @@ class RequestMatcher(abc.ABC): @abc.abstractmethod def __call__(self, request: Any) -> bool: - """ - - :param request: + """:param request: :return: True if matches the provided request object, False - otherwise """ @@ -103,10 +101,10 @@ class HttpRequestMatcher(RequestMatcher): def __init__( self, - method: Optional[str] = None, - url: Optional[str] = None, - params: Optional[Mapping[str, Any]] = None, - headers: Optional[Mapping[str, Any]] = None, + method: str | None = None, + url: str | None = None, + params: Mapping[str, Any] | None = None, + headers: Mapping[str, Any] | None = None, ): """Constructor @@ -131,9 +129,7 @@ def _match_dict(obj: Mapping[str, Any], pattern: Mapping[str, Any]) -> bool: return pattern.items() <= obj.items() def __call__(self, request: Any) -> bool: - """ - - :param request: + """:param request: :return: True if matches the provided request object, False - otherwise """ if isinstance(request, requests.Request): @@ -171,19 +167,16 @@ def matches(self, request: Any) -> bool: :param request: :return: True if policy should apply to this request, False - otherwise """ - if not self._matchers: return True return any(matcher(request) for matcher in self._matchers) class UnlimitedCallRatePolicy(BaseCallRatePolicy): - """ - This policy is for explicit unlimited call rates. + """This policy is for explicit unlimited call rates. It can be used when we want to match a specific group of requests and don't apply any limits. 
Example: - APICallBudget( [ UnlimitedCallRatePolicy( @@ -204,9 +197,7 @@ class UnlimitedCallRatePolicy(BaseCallRatePolicy): def try_acquire(self, request: Any, weight: int) -> None: """Do nothing""" - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Do nothing""" @@ -225,7 +216,6 @@ def __init__( :param call_limit: :param matchers: """ - self._next_reset_ts = next_reset_ts self._offset = period self._call_limit = call_limit @@ -258,9 +248,7 @@ def try_acquire(self, request: Any, weight: int) -> None: self._calls_num += weight - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Update call rate counters, by default, only reacts to decreasing updates of available_calls and changes to call_reset_ts. We ignore updates with available_calls > current_available_calls to support call rate limits that are lower than API limits. @@ -296,8 +284,7 @@ def _update_current_window(self) -> None: class MovingWindowCallRatePolicy(BaseCallRatePolicy): - """ - Policy to control requests rate implemented on top of PyRateLimiter lib. + """Policy to control requests rate implemented on top of PyRateLimiter lib. The main difference between this policy and FixedWindowCallRatePolicy is that the rate-limiting window is moving along requests that we made, and there is no moment when we reset an available number of calls. This strategy requires saving of timestamps of all requests within a window. @@ -342,9 +329,7 @@ def try_acquire(self, request: Any, weight: int) -> None: time_to_wait=timedelta(milliseconds=time_to_wait), ) - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Adjust call bucket to reflect the state of the API server :param available_calls: @@ -376,9 +361,7 @@ class AbstractAPIBudget(abc.ABC): """ @abc.abstractmethod - def acquire_call( - self, request: Any, block: bool = True, timeout: Optional[float] = None - ) -> None: + def acquire_call(self, request: Any, block: bool = True, timeout: float | None = None) -> None: """Try to get a call from budget, will block by default :param request: @@ -388,7 +371,7 @@ def acquire_call( """ @abc.abstractmethod - def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]: + def get_matching_policy(self, request: Any) -> AbstractCallRatePolicy | None: """Find matching call rate policy for specific request""" @abc.abstractmethod @@ -412,19 +395,16 @@ def __init__( :param maximum_attempts_to_acquire: number of attempts before throwing hit ratelimit exception, we put some big number here to avoid situations when many threads compete with each other for a few lots over a significant amount of time """ - self._policies = policies self._maximum_attempts_to_acquire = maximum_attempts_to_acquire - def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]: + def get_matching_policy(self, request: Any) -> AbstractCallRatePolicy | None: for policy in self._policies: if policy.matches(request): return policy return None - def acquire_call( - self, request: Any, block: bool = True, timeout: Optional[float] = None - ) -> None: + def acquire_call(self, request: Any, 
block: bool = True, timeout: float | None = None) -> None: """Try to get a call from budget, will block by default. Matchers will be called sequentially in the same order they were added. The first matcher that returns True will @@ -434,7 +414,6 @@ def acquire_call( :param timeout: if provided will limit maximum time in block, otherwise will wait until credit is available :raises: CallRateLimitHit - when no calls left and if timeout was set the waiting time exceed the timeout """ - policy = self.get_matching_policy(request) if policy: self._do_acquire(request=request, policy=policy, block=block, timeout=timeout) @@ -450,7 +429,7 @@ def update_from_response(self, request: Any, response: Any) -> None: pass def _do_acquire( - self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: Optional[float] + self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: float | None ) -> None: """Internal method to try to acquire a call credit @@ -521,16 +500,14 @@ def update_from_response(self, request: Any, response: Any) -> None: reset_ts = self.get_reset_ts_from_response(response) policy.update(available_calls=available_calls, call_reset_ts=reset_ts) - def get_reset_ts_from_response( - self, response: requests.Response - ) -> Optional[datetime.datetime]: + def get_reset_ts_from_response(self, response: requests.Response) -> datetime.datetime | None: if response.headers.get(self._ratelimit_reset_header): return datetime.datetime.fromtimestamp( int(response.headers[self._ratelimit_reset_header]) ) return None - def get_calls_left_from_response(self, response: requests.Response) -> Optional[int]: + def get_calls_left_from_response(self, response: requests.Response) -> int | None: if response.headers.get(self._ratelimit_remaining_header): return int(response.headers[self._ratelimit_remaining_header]) diff --git a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py index 6e4ef98d..b5e72ac4 100644 --- a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py +++ b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py @@ -1,12 +1,13 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping from enum import Enum -from typing import Any, Iterable, Mapping, Optional - -from airbyte_cdk.sources.types import StreamSlice +from typing import Any from .cursor import Cursor +from airbyte_cdk.sources.types import StreamSlice class CheckpointMode(Enum): @@ -19,48 +20,42 @@ class CheckpointMode(Enum): class CheckpointReader(ABC): - """ - CheckpointReader manages how to iterate over a stream's partitions and serves as the bridge for interpreting the current state + """CheckpointReader manages how to iterate over a stream's partitions and serves as the bridge for interpreting the current state of the stream that should be emitted back to the platform. """ @abstractmethod - def next(self) -> Optional[Mapping[str, Any]]: - """ - Returns the next slice that will be used to fetch the next group of records. Returning None indicates that the reader + def next(self) -> Mapping[str, Any] | None: + """Returns the next slice that will be used to fetch the next group of records. Returning None indicates that the reader has finished iterating over all slices. 
""" @abstractmethod def observe(self, new_state: Mapping[str, Any]) -> None: - """ - Updates the internal state of the checkpoint reader based on the incoming stream state from a connector. + """Updates the internal state of the checkpoint reader based on the incoming stream state from a connector. WARNING: This is used to retain backwards compatibility with streams using the legacy get_stream_state() method. In order to uptake Resumable Full Refresh, connectors must migrate streams to use the state setter/getter methods. """ @abstractmethod - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: - """ - Retrieves the current state value of the stream. The connector does not emit state messages if the checkpoint value is None. - """ + def get_checkpoint(self) -> Mapping[str, Any] | None: + """Retrieves the current state value of the stream. The connector does not emit state messages if the checkpoint value is None.""" class IncrementalCheckpointReader(CheckpointReader): - """ - IncrementalCheckpointReader handles iterating through a stream based on partitioned windows of data that are determined + """IncrementalCheckpointReader handles iterating through a stream based on partitioned windows of data that are determined before syncing data. """ def __init__( - self, stream_state: Mapping[str, Any], stream_slices: Iterable[Optional[Mapping[str, Any]]] + self, stream_state: Mapping[str, Any], stream_slices: Iterable[Mapping[str, Any] | None] ): - self._state: Optional[Mapping[str, Any]] = stream_state + self._state: Mapping[str, Any] | None = stream_state self._stream_slices = iter(stream_slices) self._has_slices = False - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: next_slice = next(self._stream_slices) self._has_slices = True @@ -76,13 +71,12 @@ def next(self) -> Optional[Mapping[str, Any]]: def observe(self, new_state: Mapping[str, Any]) -> None: self._state = new_state - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: return self._state class CursorBasedCheckpointReader(CheckpointReader): - """ - CursorBasedCheckpointReader is used by streams that implement a Cursor in order to manage state. This allows the checkpoint + """CursorBasedCheckpointReader is used by streams that implement a Cursor in order to manage state. This allows the checkpoint reader to delegate the complexity of fetching state to the cursor and focus on the iteration over a stream's partitions. This reader supports the Cursor interface used by Python and low-code sources. Not to be confused with Cursor interface @@ -92,7 +86,7 @@ class CursorBasedCheckpointReader(CheckpointReader): def __init__( self, cursor: Cursor, - stream_slices: Iterable[Optional[Mapping[str, Any]]], + stream_slices: Iterable[Mapping[str, Any] | None], read_state_from_cursor: bool = False, ): self._cursor = cursor @@ -100,11 +94,11 @@ def __init__( # read_state_from_cursor is used to delineate that partitions should determine when to stop syncing dynamically according # to the value of the state at runtime. This currently only applies to streams that use resumable full refresh. 
self._read_state_from_cursor = read_state_from_cursor - self._current_slice: Optional[StreamSlice] = None + self._current_slice: StreamSlice | None = None self._finished_sync = False - self._previous_state: Optional[Mapping[str, Any]] = None + self._previous_state: Mapping[str, Any] | None = None - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: self.current_slice = self._find_next_slice() return self.current_slice @@ -117,18 +111,16 @@ def observe(self, new_state: Mapping[str, Any]) -> None: # while processing records pass - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: # This is used to avoid sending a duplicate state messages new_state = self._cursor.get_stream_state() if new_state != self._previous_state: self._previous_state = new_state return new_state - else: - return None + return None def _find_next_slice(self) -> StreamSlice: - """ - _find_next_slice() returns the next slice of data should be synced for the current stream according to its cursor. + """_find_next_slice() returns the next slice of data should be synced for the current stream according to its cursor. This function supports iterating over a stream's slices across two dimensions. The first dimension is the stream's partitions like parent records for a substream. The inner dimension iterates over the cursor value like a date range for incremental streams or a pagination checkpoint for resumable full refresh. @@ -145,7 +137,6 @@ def _find_next_slice(self) -> StreamSlice: 3. When stream has processed all partitions, the iterator will raise a StopIteration exception signaling there are no more slices left for extracting more records. """ - if self._read_state_from_cursor: if self.current_slice is None: # current_slice is None represents the first time we are iterating over a stream's slices. The first slice to @@ -165,36 +156,34 @@ def _find_next_slice(self) -> StreamSlice: partition=next_slice.partition, extra_fields=next_slice.extra_fields, ) - else: - state_for_slice = self._cursor.select_state(self.current_slice) - if state_for_slice == FULL_REFRESH_COMPLETE_STATE: - # If the current slice is is complete, move to the next slice and skip the next slices that already - # have the terminal complete value indicating that a previous attempt was successfully read. - # Dummy initialization for mypy since we'll iterate at least once to get the next slice - next_candidate_slice = StreamSlice(cursor_slice={}, partition={}) - has_more = True - while has_more: - next_candidate_slice = self.read_and_convert_slice() - state_for_slice = self._cursor.select_state(next_candidate_slice) - has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE - return StreamSlice( - cursor_slice=state_for_slice or {}, - partition=next_candidate_slice.partition, - extra_fields=next_candidate_slice.extra_fields, - ) - # The reader continues to process the current partition if it's state is still in progress + state_for_slice = self._cursor.select_state(self.current_slice) + if state_for_slice == FULL_REFRESH_COMPLETE_STATE: + # If the current slice is is complete, move to the next slice and skip the next slices that already + # have the terminal complete value indicating that a previous attempt was successfully read. 
+ # Dummy initialization for mypy since we'll iterate at least once to get the next slice + next_candidate_slice = StreamSlice(cursor_slice={}, partition={}) + has_more = True + while has_more: + next_candidate_slice = self.read_and_convert_slice() + state_for_slice = self._cursor.select_state(next_candidate_slice) + has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE return StreamSlice( cursor_slice=state_for_slice or {}, - partition=self.current_slice.partition, - extra_fields=self.current_slice.extra_fields, + partition=next_candidate_slice.partition, + extra_fields=next_candidate_slice.extra_fields, ) - else: - # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate - # on a fixed set of slices determined before reading records. They just iterate to the next slice - return self.read_and_convert_slice() + # The reader continues to process the current partition if it's state is still in progress + return StreamSlice( + cursor_slice=state_for_slice or {}, + partition=self.current_slice.partition, + extra_fields=self.current_slice.extra_fields, + ) + # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate + # on a fixed set of slices determined before reading records. They just iterate to the next slice + return self.read_and_convert_slice() @property - def current_slice(self) -> Optional[StreamSlice]: + def current_slice(self) -> StreamSlice | None: return self._current_slice @current_slice.setter @@ -211,8 +200,7 @@ def read_and_convert_slice(self) -> StreamSlice: class LegacyCursorBasedCheckpointReader(CursorBasedCheckpointReader): - """ - This (unfortunate) class operates like an adapter to retain backwards compatibility with legacy sources that take in stream_slice + """This (unfortunate) class operates like an adapter to retain backwards compatibility with legacy sources that take in stream_slice in the form of a Mapping instead of the StreamSlice object. Internally, the reader still operates over StreamSlices, but it is instantiated with and emits stream slices in the form of a Mapping[str, Any]. The logic of how partitions and cursors are iterated over is synonymous with CursorBasedCheckpointReader. @@ -234,7 +222,7 @@ class LegacyCursorBasedCheckpointReader(CursorBasedCheckpointReader): def __init__( self, cursor: Cursor, - stream_slices: Iterable[Optional[Mapping[str, Any]]], + stream_slices: Iterable[Mapping[str, Any] | None], read_state_from_cursor: bool = False, ): super().__init__( @@ -243,13 +231,13 @@ def __init__( read_state_from_cursor=read_state_from_cursor, ) - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: self.current_slice = self._find_next_slice() if "partition" in dict(self.current_slice): raise ValueError("Stream is configured to use invalid stream slice key 'partition'") - elif "cursor_slice" in dict(self.current_slice): + if "cursor_slice" in dict(self.current_slice): raise ValueError( "Stream is configured to use invalid stream slice key 'cursor_slice'" ) @@ -281,8 +269,7 @@ def read_and_convert_slice(self) -> StreamSlice: class ResumableFullRefreshCheckpointReader(CheckpointReader): - """ - ResumableFullRefreshCheckpointReader allows for iteration over an unbounded set of records based on the pagination strategy + """ResumableFullRefreshCheckpointReader allows for iteration over an unbounded set of records based on the pagination strategy of the stream. 
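The legacy adapter above rejects slices that use the keys `partition` or `cursor_slice`. A plausible reading, sketched below with an illustrative helper, is that a legacy stream_slice is a flat mapping built from the partition and the cursor slice, so user-defined keys with those names would collide with the reserved ones.

```python
def to_legacy_slice(partition: dict, cursor_slice: dict) -> dict:
    for reserved in ("partition", "cursor_slice"):
        if reserved in partition or reserved in cursor_slice:
            raise ValueError(f"Stream is configured to use invalid stream slice key {reserved!r}")
    return {**partition, **cursor_slice}


print(to_legacy_slice({"parent_id": "42"}, {"page": 3}))  # {'parent_id': '42', 'page': 3}
```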
Because the number of pages is unknown, the stream's current state is used to determine whether to continue fetching more pages or stopping the sync. """ @@ -293,33 +280,31 @@ def __init__(self, stream_state: Mapping[str, Any]): self._first_page = bool(stream_state == {}) self._state: Mapping[str, Any] = stream_state - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: if self._first_page: self._first_page = False return self._state - elif self._state == FULL_REFRESH_COMPLETE_STATE: + if self._state == FULL_REFRESH_COMPLETE_STATE: return None - else: - return self._state + return self._state def observe(self, new_state: Mapping[str, Any]) -> None: self._state = new_state - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: return self._state or {} class FullRefreshCheckpointReader(CheckpointReader): - """ - FullRefreshCheckpointReader iterates over data that cannot be checkpointed incrementally during the sync because the stream + """FullRefreshCheckpointReader iterates over data that cannot be checkpointed incrementally during the sync because the stream is not capable of managing state. At the end of a sync, a final state message is emitted to signal completion. """ - def __init__(self, stream_slices: Iterable[Optional[Mapping[str, Any]]]): + def __init__(self, stream_slices: Iterable[Mapping[str, Any] | None]): self._stream_slices = iter(stream_slices) self._final_checkpoint = False - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: return next(self._stream_slices) except StopIteration: @@ -329,7 +314,7 @@ def next(self) -> Optional[Mapping[str, Any]]: def observe(self, new_state: Mapping[str, Any]) -> None: pass - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: if self._final_checkpoint: return {"__ab_no_cursor_state_message": True} return None diff --git a/airbyte_cdk/sources/streams/checkpoint/cursor.py b/airbyte_cdk/sources/streams/checkpoint/cursor.py index 6d758bf4..9a059fee 100644 --- a/airbyte_cdk/sources/streams/checkpoint/cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/cursor.py @@ -1,31 +1,29 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from airbyte_cdk.sources.types import Record, StreamSlice, StreamState class Cursor(ABC): - """ - Cursors are components that allow for checkpointing the current state of a sync. They keep track of what data has been consumed + """Cursors are components that allow for checkpointing the current state of a sync. They keep track of what data has been consumed and allows for syncs to be resumed from a specific point based on that information. """ @abstractmethod def set_initial_state(self, stream_state: StreamState) -> None: - """ - Cursors are not initialized with their state. As state is needed in order to function properly, this method should be called + """Cursors are not initialized with their state. 
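A standalone sketch of the resumable-full-refresh page loop implied by the reader above: it keeps handing back the most recently observed state (for example a page token) until the terminal complete value is observed, at which point `next()` returns None and the sync stops. `ToyRfrReader` is illustrative only.

```python
FULL_REFRESH_COMPLETE_STATE = {"__ab_full_refresh_sync_complete": True}


class ToyRfrReader:
    def __init__(self, stream_state):
        self._first_page = stream_state == {}
        self._state = stream_state

    def next(self):
        if self._first_page:
            self._first_page = False
            return self._state
        if self._state == FULL_REFRESH_COMPLETE_STATE:
            return None
        return self._state

    def observe(self, new_state):
        self._state = new_state


reader = ToyRfrReader({})
pages = iter([{"next_page": 2}, {"next_page": 3}, FULL_REFRESH_COMPLETE_STATE])
while reader.next() is not None:
    reader.observe(next(pages))  # pretend each request yields the next page token
print("sync finished")
```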
As state is needed in order to function properly, this method should be called before calling anything else :param stream_state: The state of the stream as returned by get_stream_state """ def observe(self, stream_slice: StreamSlice, record: Record) -> None: - """ - Register a record with the cursor; the cursor instance can then use it to manage the state of the in-progress stream read. + """Register a record with the cursor; the cursor instance can then use it to manage the state of the in-progress stream read. :param stream_slice: The current slice, which may or may not contain the most recently observed record :param record: the most recently-read record, which the cursor can use to update the stream state. Outwardly-visible changes to the @@ -35,8 +33,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: @abstractmethod def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: - """ - Update state based on the stream slice. Note that `stream_slice.cursor_slice` and `most_recent_record.associated_slice` are expected + """Update state based on the stream slice. Note that `stream_slice.cursor_slice` and `most_recent_record.associated_slice` are expected to be the same but we make it explicit here that `stream_slice` should be leveraged to update the state. We do not pass in the latest record, since cursor instances should maintain the relevant internal state on their own. @@ -45,8 +42,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: @abstractmethod def get_stream_state(self) -> StreamState: - """ - Returns the current stream state. We would like to restrict it's usage since it does expose internal of state. As of 2023-06-14, it + """Returns the current stream state. We would like to restrict it's usage since it does expose internal of state. As of 2023-06-14, it is used for two things: * Interpolation of the requests * Transformation of records @@ -58,20 +54,15 @@ def get_stream_state(self) -> StreamState: @abstractmethod def should_be_synced(self, record: Record) -> bool: - """ - Evaluating if a record should be synced allows for filtering and stop condition on pagination - """ + """Evaluating if a record should be synced allows for filtering and stop condition on pagination""" @abstractmethod def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: - """ - Evaluating which record is greater in terms of cursor. This is used to avoid having to capture all the records to close a slice - """ + """Evaluating which record is greater in terms of cursor. This is used to avoid having to capture all the records to close a slice""" @abstractmethod - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: - """ - Get the state value of a specific stream_slice. For incremental or resumable full refresh cursors which only manage state in + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: + """Get the state value of a specific stream_slice. For incremental or resumable full refresh cursors which only manage state in a single dimension this is the entire state object. For per-partition cursors used by substreams, this returns the state of a specific parent delineated by the incoming slice's partition object. 
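A standalone sketch of a cursor that satisfies the contract described above for a date-based stream: it observes records to advance a high-water mark and exposes that mark as stream state. `ToyDateCursor` is illustrative only and does not subclass the CDK's `Cursor`.

```python
class ToyDateCursor:
    def __init__(self, cursor_field: str = "updated_at"):
        self._cursor_field = cursor_field
        self._state: dict = {}

    def set_initial_state(self, stream_state: dict) -> None:
        self._state = dict(stream_state)

    def observe(self, stream_slice: dict, record: dict) -> None:
        latest = self._state.get(self._cursor_field, "")
        self._state[self._cursor_field] = max(latest, record[self._cursor_field])

    def close_slice(self, stream_slice: dict) -> None:
        pass  # the high-water mark already advanced in observe()

    def get_stream_state(self) -> dict:
        return self._state

    def should_be_synced(self, record: dict) -> bool:
        # only keep records at or beyond the current cursor value
        return record[self._cursor_field] >= self._state.get(self._cursor_field, "")


cursor = ToyDateCursor()
cursor.set_initial_state({"updated_at": "2024-03-01"})
cursor.observe({"start": "2024-03-01"}, {"id": 1, "updated_at": "2024-03-15"})
print(cursor.get_stream_state())  # {'updated_at': '2024-03-15'}
```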
""" diff --git a/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py b/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py index e0dee4a9..b79e5bee 100644 --- a/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py +++ b/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py @@ -1,12 +1,13 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import json -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any class PerPartitionKeySerializer: - """ - We are concerned of the performance of looping through the `states` list and evaluating equality on the partition. To reduce this + """We are concerned of the performance of looping through the `states` list and evaluating equality on the partition. To reduce this concern, we wanted to use dictionaries to map `partition -> cursor`. However, partitions are dict and dict can't be used as dict keys since they are not hashable. By creating json string using the dict, we can have a use the dict as a key to the dict since strings are hashable. diff --git a/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py index 86abd253..108ef260 100644 --- a/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from airbyte_cdk.sources.streams.checkpoint import Cursor from airbyte_cdk.sources.types import Record, StreamSlice, StreamState @@ -9,8 +10,7 @@ @dataclass class ResumableFullRefreshCursor(Cursor): - """ - Cursor that allows for the checkpointing of sync progress according to a synthetic cursor based on the pagination state + """Cursor that allows for the checkpointing of sync progress according to a synthetic cursor based on the pagination state of the stream. Resumable full refresh syncs are only intended to retain state in between sync attempts of the same job with the platform responsible for removing said state. """ @@ -25,27 +25,22 @@ def set_initial_state(self, stream_state: StreamState) -> None: self._cursor = stream_state def observe(self, stream_slice: StreamSlice, record: Record) -> None: - """ - Resumable full refresh manages state using a page number so it does not need to update state by observing incoming records. - """ + """Resumable full refresh manages state using a page number so it does not need to update state by observing incoming records.""" pass def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: self._cursor = stream_slice.cursor_slice def should_be_synced(self, record: Record) -> bool: - """ - Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages + """Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages that don't have filterable bounds. We should always return them. """ return True def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: - """ - RFR record don't have ordering to be compared between one another. 
- """ + """RFR record don't have ordering to be compared between one another.""" return False - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # A top-level RFR cursor only manages the state of a single partition return self._cursor diff --git a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py index 9966959f..0bcb0adf 100644 --- a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py @@ -1,7 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass -from typing import Any, Mapping, MutableMapping, Optional +from typing import Any from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.checkpoint import Cursor @@ -11,6 +13,7 @@ from airbyte_cdk.sources.types import Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException + FULL_REFRESH_COMPLETE_STATE: Mapping[str, Any] = {"__ab_full_refresh_sync_complete": True} @@ -24,8 +27,7 @@ def get_stream_state(self) -> StreamState: return {"states": list(self._per_partition_state.values())} def set_initial_state(self, stream_state: StreamState) -> None: - """ - Set the initial state for the cursors. + """Set the initial state for the cursors. This method initializes the state for each partition cursor using the provided stream state. If a partition state is provided in the stream state, it will update the corresponding partition cursor with this state. @@ -71,9 +73,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: self._per_partition_state[self._to_partition_key(state["partition"])] = state def observe(self, stream_slice: StreamSlice, record: Record) -> None: - """ - Substream resumable full refresh manages state by closing the slice after syncing a parent so observe is not used. - """ + """Substream resumable full refresh manages state by closing the slice after syncing a parent so observe is not used.""" pass def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: @@ -83,19 +83,16 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: } def should_be_synced(self, record: Record) -> bool: - """ - Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages + """Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages that don't have filterable bounds. We should always return them. """ return True def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: - """ - RFR record don't have ordering to be compared between one another. 
- """ + """RFR record don't have ordering to be compared between one another.""" return False - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: if not stream_slice: raise ValueError("A partition needs to be provided in order to extract a state") diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py index da99ae10..9f113cf8 100644 --- a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py @@ -1,22 +1,24 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any + +from deprecated.classic import deprecated from airbyte_cdk.models import AirbyteStream from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.streams.concurrent.availability_strategy import StreamAvailability from airbyte_cdk.sources.streams.concurrent.cursor import Cursor from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition -from deprecated.classic import deprecated @deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning) class AbstractStream(ABC): - """ - AbstractStream is an experimental interface for streams developed as part of the Concurrent CDK. + """AbstractStream is an experimental interface for streams developed as part of the Concurrent CDK. This interface is not yet stable and may change in the future. Use at your own risk. Why create a new interface instead of adding concurrency capabilities the existing Stream? @@ -40,53 +42,39 @@ class AbstractStream(ABC): @abstractmethod def generate_partitions(self) -> Iterable[Partition]: - """ - Generates the partitions that will be read by this stream. + """Generates the partitions that will be read by this stream. :return: An iterable of partitions. """ @property @abstractmethod def name(self) -> str: - """ - :return: The stream name - """ + """:return: The stream name""" @property @abstractmethod - def cursor_field(self) -> Optional[str]: - """ - Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. + def cursor_field(self) -> str | None: + """Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. Nested cursor fields are not supported. """ @abstractmethod def check_availability(self) -> StreamAvailability: - """ - :return: The stream's availability - """ + """:return: The stream's availability""" @abstractmethod def get_json_schema(self) -> Mapping[str, Any]: - """ - :return: A dict of the JSON schema representing this stream. - """ + """:return: A dict of the JSON schema representing this stream.""" @abstractmethod def as_airbyte_stream(self) -> AirbyteStream: - """ - :return: A dict of the JSON schema representing this stream. - """ + """:return: A dict of the JSON schema representing this stream.""" @abstractmethod def log_stream_sync_configuration(self) -> None: - """ - Logs the stream's configuration for debugging purposes. 
- """ + """Logs the stream's configuration for debugging purposes.""" @property @abstractmethod def cursor(self) -> Cursor: - """ - :return: The cursor associated with this stream. - """ + """:return: The cursor associated with this stream.""" diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py index 18cacbc5..8745c5d8 100644 --- a/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py +++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py @@ -1,19 +1,19 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage + StreamType = TypeVar("StreamType") -class AbstractStreamFacade(Generic[StreamType], ABC): +class AbstractStreamFacade(ABC, Generic[StreamType]): @abstractmethod def get_underlying_stream(self) -> StreamType: - """ - Return the underlying stream facade object. - """ + """Return the underlying stream facade object.""" ... @property @@ -21,9 +21,8 @@ def source_defined_cursor(self) -> bool: # Streams must be aware of their cursor at instantiation time return True - def get_error_display_message(self, exception: BaseException) -> Optional[str]: - """ - Retrieves the user-friendly display message that corresponds to an exception. + def get_error_display_message(self, exception: BaseException) -> str | None: + """Retrieves the user-friendly display message that corresponds to an exception. This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage. @@ -33,5 +32,4 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]: """ if isinstance(exception, ExceptionWithDisplayMessage): return exception.display_message - else: - return None + return None diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index d4b539a5..0b15be5e 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -1,12 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import copy import json import logging -from functools import lru_cache -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping, MutableMapping +from functools import cache +from typing import Any + +from deprecated.classic import deprecated from airbyte_cdk.models import ( AirbyteLogMessage, @@ -45,7 +49,7 @@ from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.sources.utils.schema_helpers import InternalConfig from airbyte_cdk.sources.utils.slice_logger import SliceLogger -from deprecated.classic import deprecated + """ This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream @@ -54,8 +58,7 @@ @deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning) class StreamFacade(AbstractStreamFacade[DefaultStream], Stream): - """ - The StreamFacade is a Stream that wraps an AbstractStream and exposes it as a Stream. 
+ """The StreamFacade is a Stream that wraps an AbstractStream and exposes it as a Stream. All methods either delegate to the wrapped AbstractStream or provide a default implementation. The default implementations define restrictions imposed on Streams migrated to the new interface. For instance, only source-defined cursors are supported. @@ -67,11 +70,10 @@ def create_from_stream( stream: Stream, source: AbstractSource, logger: logging.Logger, - state: Optional[MutableMapping[str, Any]], + state: MutableMapping[str, Any] | None, cursor: Cursor, ) -> Stream: - """ - Create a ConcurrentStream from a Stream object. + """Create a ConcurrentStream from a Stream object. :param source: The source :param stream: The stream :param max_workers: The maximum number of worker thread to use @@ -132,9 +134,7 @@ def __init__( slice_logger: SliceLogger, logger: logging.Logger, ): - """ - :param stream: The underlying AbstractStream - """ + """:param stream: The underlying AbstractStream""" self._abstract_stream = stream self._legacy_stream = legacy_stream self._cursor = cursor @@ -155,9 +155,9 @@ def read( def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: try: yield from self._read_records() @@ -187,22 +187,21 @@ def name(self) -> str: return self._abstract_stream.name @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: # This method is not expected to be called directly. It is only implemented for backward compatibility with the old interface return self.as_airbyte_stream().source_defined_primary_key # type: ignore # source_defined_primary_key is known to be an Optional[List[List[str]]] @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: if self._abstract_stream.cursor_field is None: return [] - else: - return self._abstract_stream.cursor_field + return self._abstract_stream.cursor_field @property - def cursor(self) -> Optional[Cursor]: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor + def cursor(self) -> Cursor | None: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor return self._cursor - @lru_cache(maxsize=None) + @cache def get_json_schema(self) -> Mapping[str, Any]: return self._abstract_stream.get_json_schema() @@ -211,10 +210,9 @@ def supports_incremental(self) -> bool: return self._legacy_stream.supports_incremental def check_availability( - self, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: - """ - Verifies the stream is available. Delegates to the underlying AbstractStream and ignores the parameters + self, logger: logging.Logger, source: Source | None = None + ) -> tuple[bool, str | None]: + """Verifies the stream is available. 
Delegates to the underlying AbstractStream and ignores the parameters :param logger: (ignored) :param source: (ignored) :return: @@ -242,8 +240,7 @@ def default(self, obj: Any) -> Any: class StreamPartition(Partition): - """ - This class acts as an adapter between the new Partition interface and the Stream's stream_slice interface + """This class acts as an adapter between the new Partition interface and the Stream's stream_slice interface StreamPartitions are instantiated from a Stream and a stream_slice. @@ -254,15 +251,14 @@ class StreamPartition(Partition): def __init__( self, stream: Stream, - _slice: Optional[Mapping[str, Any]], + _slice: Mapping[str, Any] | None, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, cursor: Cursor, ): - """ - :param stream: The stream to delegate to + """:param stream: The stream to delegate to :param _slice: The partition's stream_slice :param message_repository: The message repository to use to emit non-record messages """ @@ -276,8 +272,7 @@ def __init__( self._is_closed = False def read(self) -> Iterable[Record]: - """ - Read messages from the stream. + """Read messages from the stream. If the StreamData is a Mapping, it will be converted to a Record. Otherwise, the message will be emitted on the message repository. """ @@ -309,7 +304,7 @@ def read(self) -> Iterable[Record]: else: raise e - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: return self._slice def __hash__(self) -> int: @@ -317,8 +312,7 @@ def __hash__(self) -> int: # Convert the slice to a string so that it can be hashed s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder) return hash((self._stream.name, s)) - else: - return hash(self._stream.name) + return hash(self._stream.name) def stream_name(self) -> str: return self._stream.name @@ -335,8 +329,7 @@ def __repr__(self) -> str: class StreamPartitionGenerator(PartitionGenerator): - """ - This class acts as an adapter between the new PartitionGenerator and Stream.stream_slices + """This class acts as an adapter between the new PartitionGenerator and Stream.stream_slices This class can be used to help enable concurrency on existing connectors without having to rewrite everything as AbstractStream. In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time. @@ -347,12 +340,11 @@ def __init__( stream: Stream, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, cursor: Cursor, ): - """ - :param stream: The stream to delegate to + """:param stream: The stream to delegate to :param message_repository: The message repository to use to emit non-record messages """ self.message_repository = message_repository @@ -378,8 +370,7 @@ def generate(self) -> Iterable[Partition]: class CursorPartitionGenerator(PartitionGenerator): - """ - This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions. + """This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions. 
It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state. @@ -394,11 +385,10 @@ def __init__( message_repository: MessageRepository, cursor: Cursor, connector_state_converter: DateTimeStreamStateConverter, - cursor_field: Optional[List[str]], - slice_boundary_fields: Optional[Tuple[str, str]], + cursor_field: list[str] | None, + slice_boundary_fields: tuple[str, str] | None, ): - """ - Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor. + """Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor. :param stream: The stream to delegate to for partition generation. :param message_repository: The message repository to use to emit non-record messages. @@ -415,15 +405,13 @@ def __init__( self._connector_state_converter = connector_state_converter def generate(self) -> Iterable[Partition]: - """ - Generate partitions based on the slices in the cursor's state. + """Generate partitions based on the slices in the cursor's state. This method iterates through the list of slices found in the cursor's state, and for each slice, it generates a `StreamPartition` object. :return: An iterable of StreamPartition objects. """ - start_boundary = ( self._slice_boundary_fields[self._START_BOUNDARY] if self._slice_boundary_fields @@ -464,10 +452,9 @@ def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy) self._abstract_availability_strategy = abstract_availability_strategy def check_availability( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: - """ - Checks stream availability. + self, stream: Stream, logger: logging.Logger, source: Source | None = None + ) -> tuple[bool, str | None]: + """Checks stream availability. Important to note that the stream and source parameters are not used by the underlying AbstractAvailabilityStrategy. diff --git a/airbyte_cdk/sources/streams/concurrent/availability_strategy.py b/airbyte_cdk/sources/streams/concurrent/availability_strategy.py index 098b24ce..22d14fd9 100644 --- a/airbyte_cdk/sources/streams/concurrent/availability_strategy.py +++ b/airbyte_cdk/sources/streams/concurrent/availability_strategy.py @@ -1,34 +1,31 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Optional -from airbyte_cdk.sources.source import ExperimentalClassWarning from deprecated.classic import deprecated +from airbyte_cdk.sources.source import ExperimentalClassWarning + class StreamAvailability(ABC): @abstractmethod def is_available(self) -> bool: - """ - :return: True if the stream is available. False if the stream is not - """ + """:return: True if the stream is available. False if the stream is not""" @abstractmethod - def message(self) -> Optional[str]: - """ - :return: A message describing why the stream is not available. If the stream is available, this should return None. - """ + def message(self) -> str | None: + """:return: A message describing why the stream is not available. 
If the stream is available, this should return None.""" class StreamAvailable(StreamAvailability): def is_available(self) -> bool: return True - def message(self) -> Optional[str]: + def message(self) -> str | None: return None @@ -39,7 +36,7 @@ def __init__(self, message: str): def is_available(self) -> bool: return False - def message(self) -> Optional[str]: + def message(self) -> str | None: return self._message @@ -49,8 +46,7 @@ def message(self) -> Optional[str]: @deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning) class AbstractAvailabilityStrategy(ABC): - """ - AbstractAvailabilityStrategy is an experimental interface developed as part of the Concurrent CDK. + """AbstractAvailabilityStrategy is an experimental interface developed as part of the Concurrent CDK. This interface is not yet stable and may change in the future. Use at your own risk. Why create a new interface instead of using the existing AvailabilityStrategy? @@ -59,8 +55,7 @@ class AbstractAvailabilityStrategy(ABC): @abstractmethod def check_availability(self, logger: logging.Logger) -> StreamAvailability: - """ - Checks stream availability. + """Checks stream availability. :param logger: logger object to use :return: A StreamAvailability object describing the stream's availability @@ -69,8 +64,7 @@ def check_availability(self, logger: logging.Logger) -> StreamAvailability: @deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning) class AlwaysAvailableAvailabilityStrategy(AbstractAvailabilityStrategy): - """ - An availability strategy that always indicates a stream is available. + """An availability strategy that always indicates a stream is available. This strategy is used to avoid breaking changes and serves as a soft deprecation of the availability strategy, allowing a smoother transition @@ -78,8 +72,7 @@ class AlwaysAvailableAvailabilityStrategy(AbstractAvailabilityStrategy): """ def check_availability(self, logger: logging.Logger) -> StreamAvailability: - """ - Checks stream availability. + """Checks stream availability. :param logger: logger object to use :return: A StreamAvailability object describing the stream's availability diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 15e9b59a..448c78cd 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import functools from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Protocol, Tuple +from collections.abc import Callable, Iterable, Mapping, MutableMapping +from typing import Any, Protocol from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import MessageRepository @@ -16,13 +18,12 @@ ) -def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any: +def _extract_value(mapping: Mapping[str, Any], path: list[str]) -> Any: return functools.reduce(lambda a, b: a[b], path, mapping) class GapType(Protocol): - """ - This is the representation of gaps between two cursor values. Examples: + """This is the representation of gaps between two cursor values. 
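The `_extract_value` helper shown above walks a nested mapping along a path with `functools.reduce`. A quick standalone check of that behaviour:

```python
import functools


def extract_value(mapping, path):
    return functools.reduce(lambda a, b: a[b], path, mapping)


record = {"data": {"attributes": {"updated_at": "2024-03-15T10:00:00Z"}}}
print(extract_value(record, ["data", "attributes", "updated_at"]))  # 2024-03-15T10:00:00Z
```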
Examples: * if cursor values are datetimes, GapType is timedelta * if cursor values are integer, GapType will also be integer """ @@ -34,19 +35,19 @@ class CursorValueType(Protocol): """Protocol for annotating comparable types.""" @abstractmethod - def __lt__(self: "CursorValueType", other: "CursorValueType") -> bool: + def __lt__(self: CursorValueType, other: CursorValueType) -> bool: pass @abstractmethod - def __ge__(self: "CursorValueType", other: "CursorValueType") -> bool: + def __ge__(self: CursorValueType, other: CursorValueType) -> bool: pass @abstractmethod - def __add__(self: "CursorValueType", other: GapType) -> "CursorValueType": + def __add__(self: CursorValueType, other: GapType) -> CursorValueType: pass @abstractmethod - def __sub__(self: "CursorValueType", other: GapType) -> "CursorValueType": + def __sub__(self: CursorValueType, other: GapType) -> CursorValueType: pass @@ -68,29 +69,23 @@ def state(self) -> MutableMapping[str, Any]: ... @abstractmethod def observe(self, record: Record) -> None: - """ - Indicate to the cursor that the record has been emitted - """ - raise NotImplementedError() + """Indicate to the cursor that the record has been emitted""" + raise NotImplementedError @abstractmethod def close_partition(self, partition: Partition) -> None: - """ - Indicate to the cursor that the partition has been successfully processed - """ - raise NotImplementedError() + """Indicate to the cursor that the partition has been successfully processed""" + raise NotImplementedError @abstractmethod def ensure_at_least_one_state_emitted(self) -> None: - """ - State messages are emitted when a partition is closed. However, the platform expects at least one state to be emitted per sync per + """State messages are emitted when a partition is closed. However, the platform expects at least one state to be emitted per sync per stream. Hence, if no partitions are generated, this method needs to be called. """ - raise NotImplementedError() + raise NotImplementedError - def generate_slices(self) -> Iterable[Tuple[Any, Any]]: - """ - Default placeholder implementation of generate_slices. + def generate_slices(self) -> Iterable[tuple[Any, Any]]: + """Default placeholder implementation of generate_slices. Subclasses can override this method to provide actual behavior. 
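As the `GapType`/`CursorValueType` protocols above describe, any pairing where values compare with `<`/`>=` and a gap can be added or subtracted will do. `datetime` with `timedelta`, or plain integers, both qualify:

```python
from datetime import datetime, timedelta

low, high = datetime(2024, 1, 1), datetime(2024, 2, 1)
gap = timedelta(days=7)

assert low < high and high >= low
assert low + gap == datetime(2024, 1, 8)
assert high - gap == datetime(2024, 1, 25)

# integer cursors work the same way, with an integer gap
assert 10 < 20 and 20 >= 10 and 10 + 5 == 15
```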
""" yield from () @@ -102,7 +97,7 @@ class FinalStateCursor(Cursor): def __init__( self, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, message_repository: MessageRepository, ) -> None: self._stream_name = stream_name @@ -125,10 +120,7 @@ def close_partition(self, partition: Partition) -> None: pass def ensure_at_least_one_state_emitted(self) -> None: - """ - Used primarily for full refresh syncs that do not have a valid cursor value to emit at the end of a sync - """ - + """Used primarily for full refresh syncs that do not have a valid cursor value to emit at the end of a sync""" self._connector_state_manager.update_state_for_stream( self._stream_name, self._stream_namespace, self.state ) @@ -145,18 +137,18 @@ class ConcurrentCursor(Cursor): def __init__( self, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, stream_state: Any, message_repository: MessageRepository, connector_state_manager: ConnectorStateManager, connector_state_converter: AbstractStreamStateConverter, cursor_field: CursorField, - slice_boundary_fields: Optional[Tuple[str, str]], - start: Optional[CursorValueType], + slice_boundary_fields: tuple[str, str] | None, + start: CursorValueType | None, end_provider: Callable[[], CursorValueType], - lookback_window: Optional[GapType] = None, - slice_range: Optional[GapType] = None, - cursor_granularity: Optional[GapType] = None, + lookback_window: GapType | None = None, + slice_range: GapType | None = None, + cursor_granularity: GapType | None = None, ) -> None: self._stream_name = stream_name self._stream_namespace = stream_namespace @@ -184,12 +176,12 @@ def cursor_field(self) -> CursorField: return self._cursor_field @property - def slice_boundary_fields(self) -> Optional[Tuple[str, str]]: + def slice_boundary_fields(self) -> tuple[str, str] | None: return self._slice_boundary_fields def _get_concurrent_state( self, state: MutableMapping[str, Any] - ) -> Tuple[CursorValueType, MutableMapping[str, Any]]: + ) -> tuple[CursorValueType, MutableMapping[str, Any]]: if self._connector_state_converter.is_state_message_compatible(state): return ( self._start or self._connector_state_converter.zero_value, @@ -293,15 +285,13 @@ def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType ) from exception def ensure_at_least_one_state_emitted(self) -> None: - """ - The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be + """The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be called. """ self._emit_state_message() - def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: - """ - Generating slices based on a few parameters: + def generate_slices(self) -> Iterable[tuple[CursorValueType, CursorValueType]]: + """Generating slices based on a few parameters: * lookback_window: Buffer to remove from END_KEY of the highest slice * slice_range: Max difference between two slices. 
If the difference between two slices is greater, multiple slices will be created * start: `_split_per_slice_range` will clip any value to `self._start which means that: @@ -368,7 +358,7 @@ def _calculate_lower_boundary_of_last_slice( def _split_per_slice_range( self, lower: CursorValueType, upper: CursorValueType, upper_is_end: bool - ) -> Iterable[Tuple[CursorValueType, CursorValueType]]: + ) -> Iterable[tuple[CursorValueType, CursorValueType]]: if lower >= upper: return @@ -400,8 +390,7 @@ def _split_per_slice_range( stop_processing = True def _evaluate_upper_safely(self, lower: CursorValueType, step: GapType) -> CursorValueType: - """ - Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date + """Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code would have broken anyway. """ diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py index eb94ebba..acef018f 100644 --- a/airbyte_cdk/sources/streams/concurrent/default_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from functools import lru_cache +from collections.abc import Iterable, Mapping +from functools import cache from logging import Logger -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any from airbyte_cdk.models import AirbyteStream, SyncMode from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -24,11 +26,11 @@ def __init__( name: str, json_schema: Mapping[str, Any], availability_strategy: AbstractAvailabilityStrategy, - primary_key: List[str], - cursor_field: Optional[str], + primary_key: list[str], + cursor_field: str | None, logger: Logger, cursor: Cursor, - namespace: Optional[str] = None, + namespace: str | None = None, ) -> None: self._stream_partition_generator = partition_generator self._name = name @@ -48,17 +50,17 @@ def name(self) -> str: return self._name @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: return self._namespace def check_availability(self) -> StreamAvailability: return self._availability_strategy.check_availability(self._logger) @property - def cursor_field(self) -> Optional[str]: + def cursor_field(self) -> str | None: return self._cursor_field - @lru_cache(maxsize=None) + @cache def get_json_schema(self) -> Mapping[str, Any]: return self._json_schema diff --git a/airbyte_cdk/sources/streams/concurrent/exceptions.py b/airbyte_cdk/sources/streams/concurrent/exceptions.py index a0cf699a..cc8b1149 100644 --- a/airbyte_cdk/sources/streams/concurrent/exceptions.py +++ b/airbyte_cdk/sources/streams/concurrent/exceptions.py @@ -1,14 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from typing import Any class ExceptionWithDisplayMessage(Exception): - """ - Exception that can be used to display a custom message to the user. 
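A minimal sketch of the range-splitting idea in `_split_per_slice_range` above: cut `[lower, upper)` into windows no wider than the slice range, guarding the addition against `OverflowError` the way `_evaluate_upper_safely` does when the default step is `timedelta.max`. The function names below mirror the diff but the bodies are simplified.

```python
from datetime import datetime, timedelta


def evaluate_upper_safely(lower: datetime, step: timedelta) -> datetime:
    try:
        return lower + step
    except OverflowError:
        return datetime.max


def split_per_slice_range(lower: datetime, upper: datetime, step: timedelta):
    current = lower
    while current < upper:
        next_upper = min(evaluate_upper_safely(current, step), upper)
        yield (current, next_upper)
        current = next_upper


print(evaluate_upper_safely(datetime(2024, 1, 1), timedelta.max))  # clamps instead of raising
print(list(split_per_slice_range(datetime(2024, 1, 1), datetime(2024, 1, 10), timedelta(days=4))))
# [(Jan 1, Jan 5), (Jan 5, Jan 9), (Jan 9, Jan 10)]
```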
- """ + """Exception that can be used to display a custom message to the user.""" def __init__(self, display_message: str, **kwargs: Any): super().__init__(**kwargs) diff --git a/airbyte_cdk/sources/streams/concurrent/helpers.py b/airbyte_cdk/sources/streams/concurrent/helpers.py index d839068a..b469c926 100644 --- a/airbyte_cdk/sources/streams/concurrent/helpers.py +++ b/airbyte_cdk/sources/streams/concurrent/helpers.py @@ -1,35 +1,30 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. - -from typing import List, Optional, Union +from __future__ import annotations from airbyte_cdk.sources.streams import Stream def get_primary_key_from_stream( - stream_primary_key: Optional[Union[str, List[str], List[List[str]]]], -) -> List[str]: + stream_primary_key: str | list[str] | list[list[str]] | None, +) -> list[str]: if stream_primary_key is None: return [] - elif isinstance(stream_primary_key, str): + if isinstance(stream_primary_key, str): return [stream_primary_key] - elif isinstance(stream_primary_key, list): + if isinstance(stream_primary_key, list): if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key): return stream_primary_key # type: ignore # We verified all items in the list are strings - else: - raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}") - else: - raise ValueError(f"Invalid type for primary key: {stream_primary_key}") + raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}") + raise ValueError(f"Invalid type for primary key: {stream_primary_key}") -def get_cursor_field_from_stream(stream: Stream) -> Optional[str]: +def get_cursor_field_from_stream(stream: Stream) -> str | None: if isinstance(stream.cursor_field, list): if len(stream.cursor_field) > 1: raise ValueError( f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}" ) - elif len(stream.cursor_field) == 0: + if len(stream.cursor_field) == 0: return None - else: - return stream.cursor_field[0] - else: - return stream.cursor_field + return stream.cursor_field[0] + return stream.cursor_field diff --git a/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py b/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py index a4dd81f2..6508b8c3 100644 --- a/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py +++ b/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py @@ -1,6 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import time from queue import Queue @@ -14,9 +16,7 @@ class PartitionEnqueuer: - """ - Generates partitions from a partition generator and puts them in a queue. - """ + """Generates partitions from a partition generator and puts them in a queue.""" def __init__( self, @@ -24,8 +24,7 @@ def __init__( thread_pool_manager: ThreadPoolManager, sleep_time_in_seconds: float = 0.1, ) -> None: - """ - :param queue: The queue to put the partitions in. + """:param queue: The queue to put the partitions in. :param throttler: The throttler to use to throttle the partition generation. """ self._queue = queue @@ -33,8 +32,7 @@ def __init__( self._sleep_time_in_seconds = sleep_time_in_seconds def generate_partitions(self, stream: AbstractStream) -> None: - """ - Generate partitions from a partition generator and put them in a queue. + """Generate partitions from a partition generator and put them in a queue. 
When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated. If an exception is encountered, the exception will be caught and put in the queue. This is very important because if we don't, the diff --git a/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte_cdk/sources/streams/concurrent/partition_reader.py index 3d23fd9c..fa44a1e4 100644 --- a/airbyte_cdk/sources/streams/concurrent/partition_reader.py +++ b/airbyte_cdk/sources/streams/concurrent/partition_reader.py @@ -1,6 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + from queue import Queue from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException @@ -12,21 +14,16 @@ class PartitionReader: - """ - Generates records from a partition and puts them in a queue. - """ + """Generates records from a partition and puts them in a queue.""" _IS_SUCCESSFUL = True def __init__(self, queue: Queue[QueueItem]) -> None: - """ - :param queue: The queue to put the records in. - """ + """:param queue: The queue to put the records in.""" self._queue = queue def process_partition(self, partition: Partition) -> None: - """ - Process a partition and put the records in the output queue. + """Process a partition and put the records in the output queue. When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated. If an exception is encountered, the exception will be caught and put in the queue. This is very important because if we don't, the diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py index 09f83d8f..fbf64b4a 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py @@ -1,30 +1,28 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.streams.concurrent.partitions.record import Record class Partition(ABC): - """ - A partition is responsible for reading a specific set of data from a source. - """ + """A partition is responsible for reading a specific set of data from a source.""" @abstractmethod def read(self) -> Iterable[Record]: - """ - Reads the data from the partition. + """Reads the data from the partition. :return: An iterable of records. """ pass @abstractmethod - def to_slice(self) -> Optional[Mapping[str, Any]]: - """ - Converts the partition to a slice that can be serialized and deserialized. + def to_slice(self) -> Mapping[str, Any] | None: + """Converts the partition to a slice that can be serialized and deserialized. Note: it would have been interesting to have a type of `Mapping[str, Comparable]` to simplify typing but some slices can have nested values ([example](https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L584-L596)) @@ -34,30 +32,25 @@ def to_slice(self) -> Optional[Mapping[str, Any]]: @abstractmethod def stream_name(self) -> str: - """ - Returns the name of the stream that this partition is reading from. 
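Both workers above rely on the same queue-and-sentinel pattern: produce items into a shared queue, push a sentinel when done, and put any exception on the queue so the consumer can surface it instead of hanging. A standalone sketch with an illustrative sentinel:

```python
from queue import Queue

SENTINEL = object()


def produce(queue: Queue, items) -> None:
    try:
        for item in items:
            queue.put(item)
        queue.put(SENTINEL)
    except Exception as exception:  # forward the failure to the consumer
        queue.put(exception)


queue: Queue = Queue()
produce(queue, ["partition_1", "partition_2"])
while (item := queue.get()) is not SENTINEL:
    if isinstance(item, Exception):
        raise item
    print("processing", item)
```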
+ """Returns the name of the stream that this partition is reading from. :return: The name of the stream. """ pass @abstractmethod def close(self) -> None: - """ - Closes the partition. - """ + """Closes the partition.""" pass @abstractmethod def is_closed(self) -> bool: - """ - Returns whether the partition is closed. + """Returns whether the partition is closed. :return: """ pass @abstractmethod def __hash__(self) -> int: - """ - Returns a hash of the partition. + """Returns a hash of the partition. Partitions must be hashable so that they can be used as keys in a dictionary. """ diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py index eff97856..471af4a6 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py @@ -1,9 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Iterable +from collections.abc import Iterable from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition @@ -11,8 +12,7 @@ class PartitionGenerator(ABC): @abstractmethod def generate(self) -> Iterable[Partition]: - """ - Generates partitions for a given sync mode. + """Generates partitions for a given sync mode. :return: An iterable of partitions """ pass diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/record.py b/airbyte_cdk/sources/streams/concurrent/partitions/record.py index e67dc656..d6877de9 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/record.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/record.py @@ -1,29 +1,30 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any -from typing import TYPE_CHECKING, Any, Mapping if TYPE_CHECKING: from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition class Record: - """ - Represents a record read from a stream. - """ + """Represents a record read from a stream.""" def __init__( self, data: Mapping[str, Any], - partition: "Partition", + partition: Partition, is_file_transfer_message: bool = False, ): self.data = data self.partition = partition self.is_file_transfer_message = is_file_transfer_message - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, Record): return False return ( diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py index 7abebe07..13931984 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py @@ -1,8 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, Union +from typing import Union from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( PartitionGenerationCompletedSentinel, @@ -12,19 +13,16 @@ class PartitionCompleteSentinel: - """ - A sentinel object indicating all records for a partition were produced. + """A sentinel object indicating all records for a partition were produced. Includes a pointer to the partition that was processed. 
""" def __init__(self, partition: Partition, is_successful: bool = True): - """ - :param partition: The partition that was processed - """ + """:param partition: The partition that was processed""" self.partition = partition self.is_successful = is_successful - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, PartitionCompleteSentinel): return self.partition == other.partition return False diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py index 1b477976..7a377594 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import MutableMapping from enum import Enum -from typing import TYPE_CHECKING, Any, List, MutableMapping, Optional, Tuple +from typing import TYPE_CHECKING, Any + if TYPE_CHECKING: from airbyte_cdk.sources.streams.concurrent.cursor import CursorField @@ -31,10 +34,9 @@ def __init__(self, is_sequential_state: bool = True): self._is_sequential_state = is_sequential_state def convert_to_state_message( - self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any] + self, cursor_field: CursorField, stream_state: MutableMapping[str, Any] ) -> MutableMapping[str, Any]: - """ - Convert the state message from the concurrency-compatible format to the stream's original format. + """Convert the state message from the concurrency-compatible format to the stream's original format. e.g. { "created": "2021-01-18T21:18:20.000Z" } @@ -47,13 +49,10 @@ def convert_to_state_message( {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)} ) return legacy_state or {} - else: - return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range) + return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range) - def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Any: - """ - Get the latest time before which all records have been processed. - """ + def _get_latest_complete_time(self, slices: list[MutableMapping[str, Any]]) -> Any: + """Get the latest time before which all records have been processed.""" if not slices: raise RuntimeError( "Expected at least one slice but there were none. This is unexpected; please contact Support." @@ -64,9 +63,7 @@ def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> A return first_interval.get("most_recent_cursor_value") or first_interval[self.START_KEY] def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: - """ - Perform any transformations needed for compatibility with the converter. 
- """ + """Perform any transformations needed for compatibility with the converter.""" for stream_slice in state.get("slices", []): stream_slice[self.START_KEY] = self._from_state_message(stream_slice[self.START_KEY]) stream_slice[self.END_KEY] = self._from_state_message(stream_slice[self.END_KEY]) @@ -75,9 +72,7 @@ def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, An def serialize( self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType ) -> MutableMapping[str, Any]: - """ - Perform any transformations needed for compatibility with the converter. - """ + """Perform any transformations needed for compatibility with the converter.""" serialized_slices = [] for stream_slice in state.get("slices", []): serialized_slice = { @@ -100,12 +95,11 @@ def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool: @abstractmethod def convert_from_sequential_state( self, - cursor_field: "CursorField", # to deprecate as it is only needed for sequential state + cursor_field: CursorField, # to deprecate as it is only needed for sequential state stream_state: MutableMapping[str, Any], - start: Optional[Any], - ) -> Tuple[Any, MutableMapping[str, Any]]: - """ - Convert the state message to the format required by the ConcurrentCursor. + start: Any | None, + ) -> tuple[Any, MutableMapping[str, Any]]: + """Convert the state message to the format required by the ConcurrentCursor. e.g. { @@ -119,16 +113,13 @@ def convert_from_sequential_state( @abstractmethod def increment(self, value: Any) -> Any: - """ - Increment a timestamp by a single unit. - """ + """Increment a timestamp by a single unit.""" ... def merge_intervals( - self, intervals: List[MutableMapping[str, Any]] - ) -> List[MutableMapping[str, Any]]: - """ - Compute and return a list of merged intervals. + self, intervals: list[MutableMapping[str, Any]] + ) -> list[MutableMapping[str, Any]]: + """Compute and return a list of merged intervals. Intervals may be merged if the start time of the second interval is 1 unit or less (as defined by the `increment` method) than the end time of the first interval. @@ -164,9 +155,7 @@ def merge_intervals( @abstractmethod def parse_value(self, value: Any) -> Any: - """ - Parse the value of the cursor field into a comparable value. - """ + """Parse the value of the cursor field into a comparable value.""" ... @property diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py index 3ff22c09..7cfaaf78 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod +from collections.abc import Callable, MutableMapping from datetime import datetime, timedelta, timezone -from typing import Any, Callable, List, MutableMapping, Optional, Tuple +from typing import Any import pendulum +from pendulum.datetime import DateTime # FIXME We would eventually like the Concurrent package do be agnostic of the declarative package. However, this is a breaking change and # the goal in the short term is only to fix the issue we are seeing for source-declarative-manifest. 
@@ -16,7 +19,6 @@ AbstractStreamStateConverter, ConcurrencyCompatibleStateType, ) -from pendulum.datetime import DateTime class DateTimeStreamStateConverter(AbstractStreamStateConverter): @@ -48,9 +50,7 @@ def parse_timestamp(self, timestamp: Any) -> datetime: ... def output_format(self, timestamp: datetime) -> Any: ... def parse_value(self, value: Any) -> Any: - """ - Parse the value of the cursor field into a comparable value. - """ + """Parse the value of the cursor field into a comparable value.""" return self.parse_timestamp(value) def _compare_intervals(self, end_time: Any, start_time: Any) -> bool: @@ -60,10 +60,9 @@ def convert_from_sequential_state( self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], - start: Optional[datetime], - ) -> Tuple[datetime, MutableMapping[str, Any]]: - """ - Convert the state message to the format required by the ConcurrentCursor. + start: datetime | None, + ) -> tuple[datetime, MutableMapping[str, Any]]: + """Convert the state message to the format required by the ConcurrentCursor. e.g. { @@ -95,7 +94,7 @@ def _get_sync_start( self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], - start: Optional[datetime], + start: datetime | None, ) -> datetime: sync_start = start if start is not None else self.zero_value prev_sync_low_water_mark = ( @@ -105,13 +104,11 @@ def _get_sync_start( ) if prev_sync_low_water_mark and prev_sync_low_water_mark >= sync_start: return prev_sync_low_water_mark - else: - return sync_start + return sync_start class EpochValueConcurrentStreamStateConverter(DateTimeStreamStateConverter): - """ - e.g. + """e.g. { "created": 1617030403 } => { @@ -141,8 +138,7 @@ def parse_timestamp(self, timestamp: int) -> datetime: class IsoMillisConcurrentStreamStateConverter(DateTimeStreamStateConverter): - """ - e.g. + """e.g. { "created": "2021-01-18T21:18:20.000Z" } => { @@ -157,7 +153,7 @@ class IsoMillisConcurrentStreamStateConverter(DateTimeStreamStateConverter): _zero_value = "0001-01-01T00:00:00.000Z" def __init__( - self, is_sequential_state: bool = True, cursor_granularity: Optional[timedelta] = None + self, is_sequential_state: bool = True, cursor_granularity: timedelta | None = None ): super().__init__(is_sequential_state=is_sequential_state) self._cursor_granularity = cursor_granularity or timedelta(milliseconds=1) @@ -178,23 +174,22 @@ def parse_timestamp(self, timestamp: str) -> datetime: class CustomFormatConcurrentStreamStateConverter(IsoMillisConcurrentStreamStateConverter): - """ - Datetime State converter that emits state according to the supplied datetime format. The converter supports reading + """Datetime State converter that emits state according to the supplied datetime format. The converter supports reading incoming state in any valid datetime format via Pendulum. 
""" def __init__( self, datetime_format: str, - input_datetime_formats: Optional[List[str]] = None, + input_datetime_formats: list[str] | None = None, is_sequential_state: bool = True, - cursor_granularity: Optional[timedelta] = None, + cursor_granularity: timedelta | None = None, ): super().__init__( is_sequential_state=is_sequential_state, cursor_granularity=cursor_granularity ) self._datetime_format = datetime_format - self._input_datetime_formats = input_datetime_formats if input_datetime_formats else [] + self._input_datetime_formats = input_datetime_formats or [] self._input_datetime_formats += [self._datetime_format] self._parser = DatetimeParser() diff --git a/airbyte_cdk/sources/streams/core.py b/airbyte_cdk/sources/streams/core.py index 90925c4c..68a809fa 100644 --- a/airbyte_cdk/sources/streams/core.py +++ b/airbyte_cdk/sources/streams/core.py @@ -1,16 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import copy import inspect import itertools import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator, Mapping, MutableMapping from dataclasses import dataclass -from functools import cached_property, lru_cache -from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Union +from functools import cache, cached_property +from typing import Any, Union + +from deprecated import deprecated -import airbyte_cdk.sources.utils.casing as casing from airbyte_cdk.models import ( AirbyteMessage, AirbyteStream, @@ -30,12 +34,13 @@ ResumableFullRefreshCheckpointReader, ) from airbyte_cdk.sources.types import StreamSlice +from airbyte_cdk.sources.utils import casing # list of all possible HTTP methods which can be used for sending of request bodies from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, ResourceSchemaLoader from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer -from deprecated import deprecated + # A stream's read method can return one of the following types: # Mapping[str, Any]: The content of an AirbyteRecordMessage @@ -52,8 +57,7 @@ def package_name_from_class(cls: object) -> str: module = inspect.getmodule(cls) if module is not None: return module.__name__.split(".")[0] - else: - raise ValueError(f"Could not find package name for class {cls}") + raise ValueError(f"Could not find package name for class {cls}") class CheckpointMixin(ABC): @@ -121,11 +125,9 @@ class StreamClassification: action="ignore", ) class Stream(ABC): - """ - Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol. - """ + """Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol.""" - _configured_json_schema: Optional[Dict[str, Any]] = None + _configured_json_schema: dict[str, Any] | None = None _exit_on_rate_limit: bool = False # Use self.logger in subclasses to log any messages @@ -136,20 +138,17 @@ def logger(self) -> logging.Logger: # TypeTransformer object to perform output data transformation transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform) - cursor: Optional[Cursor] = None + cursor: Cursor | None = None has_multiple_slices = False @cached_property def name(self) -> str: - """ - :return: Stream name. By default this is the implementing class name, but it can be overridden as needed. - """ + """:return: Stream name. 
By default this is the implementing class name, but it can be overridden as needed.""" return casing.camel_to_snake(self.__class__.__name__) - def get_error_display_message(self, exception: BaseException) -> Optional[str]: - """ - Retrieves the user-friendly display message that corresponds to an exception. + def get_error_display_message(self, exception: BaseException) -> str | None: + """Retrieves the user-friendly display message that corresponds to an exception. This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. The default implementation of this method does not return user-friendly messages for any exception type, but it should be overriden as needed. @@ -193,7 +192,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o if slice_logger.should_log_slice_message(logger): yield slice_logger.create_slice_log_message(next_slice) records = self.read_records( - sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior + sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior stream_slice=next_slice, stream_state=stream_state, cursor_field=cursor_field or None, @@ -256,13 +255,11 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o airbyte_state_message = self._checkpoint_state(checkpoint, state_manager=state_manager) yield airbyte_state_message - def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterable[StreamData]: - """ - Helper method that performs a read on a stream with an optional state and emits records. If the parent stream supports + def read_only_records(self, state: Mapping[str, Any] | None = None) -> Iterable[StreamData]: + """Helper method that performs a read on a stream with an optional state and emits records. If the parent stream supports incremental, this operation does not update the stream's internal state (if it uses the modern state setter/getter) or emit state messages. """ - configured_stream = ConfiguredAirbyteStream( stream=AirbyteStream( name=self.name, @@ -288,18 +285,15 @@ def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterab def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: - """ - This method should be overridden by subclasses to read records based on the inputs - """ + """This method should be overridden by subclasses to read records based on the inputs""" - @lru_cache(maxsize=None) + @cache def get_json_schema(self) -> Mapping[str, Any]: - """ - :return: A dict of the JSON schema representing this stream. + """:return: A dict of the JSON schema representing this stream. The default implementation of this method looks for a JSONSchema file with the same name as this stream's "name" property. Override as needed. 
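For context on the Stream hooks being retyped above, here is a minimal subclass sketch. It is not part of this patch; the Employees stream and its record are invented, and it only exercises the hooks touched in these hunks (name derives from the class name, primary_key satisfies the abstract property, get_json_schema and read_records are overridden directly).

from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import Any

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams import Stream


class Employees(Stream):
    # Satisfies the abstract primary_key property; self.name resolves to "employees".
    primary_key = "id"

    def get_json_schema(self) -> Mapping[str, Any]:
        # Return the schema inline instead of loading a bundled JSON schema file.
        return {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object",
            "properties": {"id": {"type": "integer"}, "name": {"type": "string"}},
        }

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: list[str] | None = None,
        stream_slice: Mapping[str, Any] | None = None,
        stream_state: Mapping[str, Any] | None = None,
    ) -> Iterable[Mapping[str, Any]]:
        # A real full-refresh stream would fetch from its source here.
        yield {"id": 1, "name": "Ada"}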
@@ -332,15 +326,12 @@ def as_airbyte_stream(self) -> AirbyteStream: @property def supports_incremental(self) -> bool: - """ - :return: True if this stream supports incrementally reading data - """ + """:return: True if this stream supports incrementally reading data""" return len(self._wrapped_cursor_field()) > 0 @property def is_resumable(self) -> bool: - """ - :return: True if this stream allows the checkpointing of sync progress and can resume from it on subsequent attempts. + """:return: True if this stream allows the checkpointing of sync progress and can resume from it on subsequent attempts. This differs from supports_incremental because certain kinds of streams like those supporting resumable full refresh can checkpoint progress in between attempts for improved fault tolerance. However, they will start from the beginning on the next sync job. @@ -352,38 +343,33 @@ def is_resumable(self) -> bool: # to structure stream state in a very specific way. We also can't check for issubclass(HttpSubStream) because # not all substreams implement the interface and it would be a circular dependency so we use parent as a surrogate return False - elif hasattr(type(self), "state") and getattr(type(self), "state").fset is not None: + if hasattr(type(self), "state") and type(self).state.fset is not None: # Modern case where a stream manages state using getter/setter return True - else: - # Legacy case where the CDK manages state via the get_updated_state() method. This is determined by checking if - # the stream's get_updated_state() differs from the Stream class and therefore has been overridden - return type(self).get_updated_state != Stream.get_updated_state + # Legacy case where the CDK manages state via the get_updated_state() method. This is determined by checking if + # the stream's get_updated_state() differs from the Stream class and therefore has been overridden + return type(self).get_updated_state != Stream.get_updated_state - def _wrapped_cursor_field(self) -> List[str]: + def _wrapped_cursor_field(self) -> list[str]: return [self.cursor_field] if isinstance(self.cursor_field, str) else self.cursor_field @property - def cursor_field(self) -> Union[str, List[str]]: - """ - Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. + def cursor_field(self) -> str | list[str]: + """Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor. """ return [] @property - def namespace(self) -> Optional[str]: - """ - Override to return the namespace of this stream, e.g. the Postgres schema which this stream will emit records for. + def namespace(self) -> str | None: + """Override to return the namespace of this stream, e.g. the Postgres schema which this stream will emit records for. :return: A string containing the name of the namespace. """ return None @property def source_defined_cursor(self) -> bool: - """ - Return False if the cursor can be configured by the user. 
- """ + """Return False if the cursor can be configured by the user.""" return True @property @@ -398,21 +384,19 @@ def exit_on_rate_limit(self, value: bool) -> None: @property @abstractmethod - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: - """ - :return: string if single primary key, list of strings if composite primary key, list of list of strings if composite primary key consisting of nested fields. - If the stream has no primary keys, return None. + def primary_key(self) -> str | list[str] | list[list[str]] | None: + """:return: string if single primary key, list of strings if composite primary key, list of list of strings if composite primary key consisting of nested fields. + If the stream has no primary keys, return None. """ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: - """ - Override to define the slices for this stream. See the stream slicing section of the docs for more information. + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: + """Override to define the slices for this stream. See the stream slicing section of the docs for more information. :param sync_mode: :param cursor_field: @@ -422,9 +406,8 @@ def stream_slices( yield StreamSlice(partition={}, cursor_slice={}) @property - def state_checkpoint_interval(self) -> Optional[int]: - """ - Decides how often to checkpoint state (i.e: emit a STATE message). E.g: if this returns a value of 100, then state is persisted after reading + def state_checkpoint_interval(self) -> int | None: + """Decides how often to checkpoint state (i.e: emit a STATE message). E.g: if this returns a value of 100, then state is persisted after reading 100 records, then 200, 300, etc.. A good default value is 1000 although your mileage may vary depending on the underlying data source. Checkpointing a stream avoids re-reading records in the case a sync is failed or cancelled. @@ -451,9 +434,8 @@ def get_updated_state( """ return {} - def get_cursor(self) -> Optional[Cursor]: - """ - A Cursor is an interface that a stream can implement to manage how its internal state is read and updated while + def get_cursor(self) -> Cursor | None: + """A Cursor is an interface that a stream can implement to manage how its internal state is read and updated while reading records. Historically, Python connectors had no concept of a cursor to manage state. Python streams need to define a cursor implementation and override this method to manage state through a Cursor. 
""" @@ -462,13 +444,13 @@ def get_cursor(self) -> Optional[Cursor]: def _get_checkpoint_reader( self, logger: logging.Logger, - cursor_field: Optional[List[str]], + cursor_field: list[str] | None, sync_mode: SyncMode, stream_state: MutableMapping[str, Any], ) -> CheckpointReader: mappings_or_slices = self.stream_slices( cursor_field=cursor_field, - sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior + sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior stream_state=stream_state, ) @@ -501,38 +483,35 @@ def _get_checkpoint_reader( return LegacyCursorBasedCheckpointReader( stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=True ) - elif cursor: + if cursor: return CursorBasedCheckpointReader( stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH, ) - elif checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH: + if checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH: # Resumable full refresh readers rely on the stream state dynamically being updated during pagination and does # not iterate over a static set of slices. return ResumableFullRefreshCheckpointReader(stream_state=stream_state) - elif checkpoint_mode == CheckpointMode.INCREMENTAL: + if checkpoint_mode == CheckpointMode.INCREMENTAL: return IncrementalCheckpointReader( stream_slices=slices_iterable_copy, stream_state=stream_state ) - else: - return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy) + return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy) @property def _checkpoint_mode(self) -> CheckpointMode: if self.is_resumable and len(self._wrapped_cursor_field()) > 0: return CheckpointMode.INCREMENTAL - elif self.is_resumable: + if self.is_resumable: return CheckpointMode.RESUMABLE_FULL_REFRESH - else: - return CheckpointMode.FULL_REFRESH + return CheckpointMode.FULL_REFRESH @staticmethod def _classify_stream( - mappings_or_slices: Iterator[Optional[Union[Mapping[str, Any], StreamSlice]]], + mappings_or_slices: Iterator[Mapping[str, Any] | StreamSlice | None], ) -> StreamClassification: - """ - This is a bit of a crazy solution, but also the only way we can detect certain attributes about the stream since Python + """This is a bit of a crazy solution, but also the only way we can detect certain attributes about the stream since Python streams do not follow consistent implementation patterns. We care about the following two attributes: - is_substream: Helps to incrementally release changes since substreams w/ parents are much more complicated. Also helps de-risk the release of changes that might impact all connectors @@ -584,9 +563,7 @@ def _classify_stream( ) def log_stream_sync_configuration(self) -> None: - """ - Logs the configuration of this stream. - """ + """Logs the configuration of this stream.""" self.logger.debug( f"Syncing stream instance: {self.name}", extra={ @@ -597,17 +574,15 @@ def log_stream_sync_configuration(self) -> None: @staticmethod def _wrapped_primary_key( - keys: Optional[Union[str, List[str], List[List[str]]]], - ) -> Optional[List[List[str]]]: - """ - :return: wrap the primary_key property in a list of list of strings required by the Airbyte Stream object. 
- """ + keys: str | list[str] | list[list[str]] | None, + ) -> list[list[str]] | None: + """:return: wrap the primary_key property in a list of list of strings required by the Airbyte Stream object.""" if not keys: return None if isinstance(keys, str): return [[keys]] - elif isinstance(keys, list): + if isinstance(keys, list): wrapped_keys = [] for component in keys: if isinstance(component, str): @@ -617,18 +592,15 @@ def _wrapped_primary_key( else: raise ValueError(f"Element must be either list or str. Got: {type(component)}") return wrapped_keys - else: - raise ValueError(f"Element must be either list or str. Got: {type(keys)}") + raise ValueError(f"Element must be either list or str. Got: {type(keys)}") def _observe_state( - self, checkpoint_reader: CheckpointReader, stream_state: Optional[Mapping[str, Any]] = None + self, checkpoint_reader: CheckpointReader, stream_state: Mapping[str, Any] | None = None ) -> None: - """ - Convenience method that attempts to read the Stream's state using the recommended way of connector's managing their + """Convenience method that attempts to read the Stream's state using the recommended way of connector's managing their own state via state setter/getter. But if we get back an AttributeError, then the legacy Stream.get_updated_state() method is used as a fallback method. """ - # This is an inversion of the original logic that used to try state getter/setters first. As part of the work to # automatically apply resumable full refresh to all streams, all HttpStream classes implement default state # getter/setter methods, we should default to only using the incoming stream_state parameter value is {} which @@ -650,29 +622,27 @@ def _checkpoint_state( # type: ignore # ignoring typing for ConnectorStateMana stream_state: Mapping[str, Any], state_manager, ) -> AirbyteMessage: - # todo: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want + # TODO: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want # to reduce changes right now and this would span concurrent as well state_manager.update_state_for_stream(self.name, self.namespace, stream_state) return state_manager.create_state_message(self.name, self.namespace) @property - def configured_json_schema(self) -> Optional[Dict[str, Any]]: - """ - This property is set from the read method. + def configured_json_schema(self) -> dict[str, Any] | None: + """This property is set from the read method. :return Optional[Dict]: JSON schema from configured catalog if provided, otherwise None. """ return self._configured_json_schema @configured_json_schema.setter - def configured_json_schema(self, json_schema: Dict[str, Any]) -> None: + def configured_json_schema(self, json_schema: dict[str, Any]) -> None: self._configured_json_schema = self._filter_schema_invalid_properties(json_schema) def _filter_schema_invalid_properties( - self, configured_catalog_json_schema: Dict[str, Any] - ) -> Dict[str, Any]: - """ - Filters the properties in json_schema that are not present in the stream schema. + self, configured_catalog_json_schema: dict[str, Any] + ) -> dict[str, Any]: + """Filters the properties in json_schema that are not present in the stream schema. Configured Schemas can have very old fields, so we need to housekeeping ourselves. 
""" configured_schema: Any = configured_catalog_json_schema.get("properties", {}) diff --git a/airbyte_cdk/sources/streams/http/availability_strategy.py b/airbyte_cdk/sources/streams/http/availability_strategy.py index 494fcf15..7679eefe 100644 --- a/airbyte_cdk/sources/streams/http/availability_strategy.py +++ b/airbyte_cdk/sources/streams/http/availability_strategy.py @@ -1,25 +1,25 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import typing -from typing import Optional, Tuple from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy from airbyte_cdk.utils.traced_exception import AirbyteTracedException + if typing.TYPE_CHECKING: from airbyte_cdk.sources import Source class HttpAvailabilityStrategy(AvailabilityStrategy): def check_availability( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: - """ - Check stream availability by attempting to read the first record of the + self, stream: Stream, logger: logging.Logger, source: Source | None = None + ) -> tuple[bool, str | None]: + """Check stream availability by attempting to read the first record of the stream. :param stream: stream @@ -30,7 +30,7 @@ def check_availability( for some reason and the str should describe what went wrong and how to resolve the unavailability, if possible. """ - reason: Optional[str] + reason: str | None try: # Some streams need a stream slice to read records (e.g. if they have a SubstreamPartitionRouter) # Streams that don't need a stream slice will return `None` as their first stream slice. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py b/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py index 6ed82179..e1e743b0 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py @@ -1,9 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Union import requests @@ -12,11 +12,10 @@ class BackoffStrategy(ABC): @abstractmethod def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: - """ - Override this method to dynamically determine backoff time e.g: by reading the X-Retry-After header. + ) -> float | None: + """Override this method to dynamically determine backoff time e.g: by reading the X-Retry-After header. This method is called only if should_backoff() returns True for the input request. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py b/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py index 2c3e10ad..db8efc89 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py @@ -1,7 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
- - -from typing import Optional, Union +from __future__ import annotations import requests @@ -11,7 +9,7 @@ class DefaultBackoffStrategy(BackoffStrategy): def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: return None diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py index fa8864db..62d033db 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py @@ -1,17 +1,20 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Mapping, Type, Union +from collections.abc import Mapping + +from requests.exceptions import InvalidSchema, InvalidURL, RequestException from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( ErrorResolution, ResponseAction, ) -from requests.exceptions import InvalidSchema, InvalidURL, RequestException -DEFAULT_ERROR_MAPPING: Mapping[Union[int, str, Type[Exception]], ErrorResolution] = { + +DEFAULT_ERROR_MAPPING: Mapping[int | str | type[Exception], ErrorResolution] = { InvalidSchema: ErrorResolution( response_action=ResponseAction.FAIL, failure_type=FailureType.config_error, diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py index b231e72e..7c22555d 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Union import requests @@ -9,32 +9,23 @@ class ErrorHandler(ABC): - """ - Abstract base class to determine how to handle a failed HTTP request. - """ + """Abstract base class to determine how to handle a failed HTTP request.""" @property @abstractmethod - def max_retries(self) -> Optional[int]: - """ - The maximum number of retries to attempt before giving up. - """ + def max_retries(self) -> int | None: + """The maximum number of retries to attempt before giving up.""" pass @property @abstractmethod - def max_time(self) -> Optional[int]: - """ - The maximum amount of time in seconds to retry before giving up. - """ + def max_time(self) -> int | None: + """The maximum amount of time in seconds to retry before giving up.""" pass @abstractmethod - def interpret_response( - self, response: Optional[Union[requests.Response, Exception]] - ) -> ErrorResolution: - """ - Interpret the response or exception and return the corresponding response action, failure type, and error message. + def interpret_response(self, response: requests.Response | Exception | None) -> ErrorResolution: + """Interpret the response or exception and return the corresponding response action, failure type, and error message. :param response: The HTTP response object or exception raised during the request. :return: A tuple containing the response action, failure type, and error message. 
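As a reference point for the ErrorHandler interface rewritten above, the following is an illustrative concrete handler, not part of this patch; the ClientErrorFailFastHandler name, the retry limits, and the "fail fast on 4xx, retry otherwise" policy are made up for the example.

from __future__ import annotations

import requests

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.streams.http.error_handlers.error_handler import ErrorHandler
from airbyte_cdk.sources.streams.http.error_handlers.response_models import (
    ErrorResolution,
    ResponseAction,
)


class ClientErrorFailFastHandler(ErrorHandler):
    @property
    def max_retries(self) -> int | None:
        return 5

    @property
    def max_time(self) -> int | None:
        return 600  # seconds

    def interpret_response(
        self, response: requests.Response | Exception | None
    ) -> ErrorResolution:
        if isinstance(response, requests.Response):
            if response.ok:
                return ErrorResolution(response_action=ResponseAction.SUCCESS)
            if 400 <= response.status_code < 500 and response.status_code != 429:
                # Non-retryable client errors usually indicate a configuration problem.
                return ErrorResolution(
                    response_action=ResponseAction.FAIL,
                    failure_type=FailureType.config_error,
                    error_message=f"Client error {response.status_code}; check the configuration.",
                )
        # Everything else (5xx, 429, exceptions) is treated as transient and retried.
        return ErrorResolution(
            response_action=ResponseAction.RETRY,
            failure_type=FailureType.transient_error,
            error_message="Retrying after a transient error.",
        )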
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py b/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py index 966fe93a..e94fe452 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py @@ -1,18 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional import requests class ErrorMessageParser(ABC): @abstractmethod - def parse_response_error_message(self, response: requests.Response) -> Optional[str]: - """ - Parse error message from response. + def parse_response_error_message(self, response: requests.Response) -> str | None: + """Parse error message from response. :param response: response received for the request :return: error message """ diff --git a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py index f18e3db2..0a59c22f 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py @@ -1,12 +1,14 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Mapping from datetime import timedelta -from typing import Mapping, Optional, Union import requests + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import ( DEFAULT_ERROR_MAPPING, @@ -22,12 +24,11 @@ class HttpStatusErrorHandler(ErrorHandler): def __init__( self, logger: logging.Logger, - error_mapping: Optional[Mapping[Union[int, str, type[Exception]], ErrorResolution]] = None, + error_mapping: Mapping[int | str | type[Exception], ErrorResolution] | None = None, max_retries: int = 5, max_time: timedelta = timedelta(seconds=600), ) -> None: - """ - Initialize the HttpStatusErrorHandler. + """Initialize the HttpStatusErrorHandler. :param error_mapping: Custom error mappings to extend or override the default mappings. """ @@ -37,41 +38,36 @@ def __init__( self._max_time = int(max_time.total_seconds()) @property - def max_retries(self) -> Optional[int]: + def max_retries(self) -> int | None: return self._max_retries @property - def max_time(self) -> Optional[int]: + def max_time(self) -> int | None: return self._max_time def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] = None + self, response_or_exception: requests.Response | Exception | None = None ) -> ErrorResolution: - """ - Interpret the response and return the corresponding response action, failure type, and error message. + """Interpret the response and return the corresponding response action, failure type, and error message. :param response: The HTTP response object. :return: A tuple containing the response action, failure type, and error message. 
""" - if isinstance(response_or_exception, Exception): - mapped_error: Optional[ErrorResolution] = self._error_mapping.get( + mapped_error: ErrorResolution | None = self._error_mapping.get( response_or_exception.__class__ ) if mapped_error is not None: return mapped_error - else: - self._logger.error( - f"Unexpected exception in error handler: {response_or_exception}" - ) - return ErrorResolution( - response_action=ResponseAction.RETRY, - failure_type=FailureType.system_error, - error_message=f"Unexpected exception in error handler: {response_or_exception}", - ) + self._logger.error(f"Unexpected exception in error handler: {response_or_exception}") + return ErrorResolution( + response_action=ResponseAction.RETRY, + failure_type=FailureType.system_error, + error_message=f"Unexpected exception in error handler: {response_or_exception}", + ) - elif isinstance(response_or_exception, requests.Response): + if isinstance(response_or_exception, requests.Response): if response_or_exception.status_code is None: self._logger.error("Response does not include an HTTP status code.") return ErrorResolution( @@ -93,17 +89,15 @@ def interpret_response( if mapped_error is not None: return mapped_error - else: - self._logger.warning(f"Unexpected HTTP Status Code in error handler: '{error_key}'") - return ErrorResolution( - response_action=ResponseAction.RETRY, - failure_type=FailureType.system_error, - error_message=f"Unexpected HTTP Status Code in error handler: {error_key}", - ) - else: - self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + self._logger.warning(f"Unexpected HTTP Status Code in error handler: '{error_key}'") return ErrorResolution( - response_action=ResponseAction.FAIL, + response_action=ResponseAction.RETRY, failure_type=FailureType.system_error, - error_message=f"Received unexpected response type: {type(response_or_exception)}", + error_message=f"Unexpected HTTP Status Code in error handler: {error_key}", ) + self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + return ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message=f"Received unexpected response type: {type(response_or_exception)}", + ) diff --git a/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py b/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py index 3ca31ec5..008fbad0 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py @@ -1,22 +1,22 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - -from typing import Optional +from __future__ import annotations import requests + from airbyte_cdk.sources.streams.http.error_handlers import ErrorMessageParser from airbyte_cdk.sources.utils.types import JsonType class JsonErrorMessageParser(ErrorMessageParser): - def _try_get_error(self, value: Optional[JsonType]) -> Optional[str]: + def _try_get_error(self, value: JsonType | None) -> str | None: if isinstance(value, str): return value - elif isinstance(value, list): + if isinstance(value, list): errors_in_value = [self._try_get_error(v) for v in value] return ", ".join(v for v in errors_in_value if v is not None) - elif isinstance(value, dict): + if isinstance(value, dict): new_value = ( value.get("message") or value.get("messages") @@ -34,9 +34,8 @@ def _try_get_error(self, value: Optional[JsonType]) -> Optional[str]: return self._try_get_error(new_value) return None - def parse_response_error_message(self, response: requests.Response) -> Optional[str]: - """ - Parses the raw response object from a failed request into a user-friendly error message. + def parse_response_error_message(self, response: requests.Response) -> str | None: + """Parses the raw response object from a failed request into a user-friendly error message. :param response: :return: A user-friendly message that indicates the cause of the error diff --git a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py index aca13a8c..0671323c 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py @@ -1,13 +1,14 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from dataclasses import dataclass from enum import Enum -from typing import Optional, Union import requests +from requests import HTTPError + from airbyte_cdk.models import FailureType from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets -from requests import HTTPError class ResponseAction(Enum): @@ -20,13 +21,13 @@ class ResponseAction(Enum): @dataclass class ErrorResolution: - response_action: Optional[ResponseAction] = None - failure_type: Optional[FailureType] = None - error_message: Optional[str] = None + response_action: ResponseAction | None = None + failure_type: FailureType | None = None + error_message: str | None = None def _format_exception_error_message(exception: Exception) -> str: - return f"{type(exception).__name__}: {str(exception)}" + return f"{type(exception).__name__}: {exception!s}" def _format_response_error_message(response: requests.Response) -> str: @@ -34,7 +35,7 @@ def _format_response_error_message(response: requests.Response) -> str: response.raise_for_status() except HTTPError as exception: return filter_secrets( - f"Response was not ok: `{str(exception)}`. Response content is: {response.text}" + f"Response was not ok: `{exception!s}`. Response content is: {response.text}" ) # We purposefully do not add the response.content because the response is "ok" so there might be sensitive information in the payload. 
# Feel free the @@ -42,7 +43,7 @@ def _format_response_error_message(response: requests.Response) -> str: def create_fallback_error_resolution( - response_or_exception: Optional[Union[requests.Response, Exception]], + response_or_exception: requests.Response | Exception | None, ) -> ErrorResolution: if response_or_exception is None: # We do not expect this case to happen but if it does, it would be good to understand the cause and improve the error message diff --git a/airbyte_cdk/sources/streams/http/exceptions.py b/airbyte_cdk/sources/streams/http/exceptions.py index ee468762..c904fe0b 100644 --- a/airbyte_cdk/sources/streams/http/exceptions.py +++ b/airbyte_cdk/sources/streams/http/exceptions.py @@ -1,9 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - - -from typing import Optional, Union +from __future__ import annotations import requests @@ -12,7 +10,7 @@ class BaseBackoffException(requests.exceptions.HTTPError): def __init__( self, request: requests.PreparedRequest, - response: Optional[Union[requests.Response, Exception]], + response: requests.Response | Exception | None, error_message: str = "", ): if isinstance(response, requests.Response): @@ -27,25 +25,20 @@ def __init__( class RequestBodyException(Exception): - """ - Raised when there are issues in configuring a request body - """ + """Raised when there are issues in configuring a request body""" class UserDefinedBackoffException(BaseBackoffException): - """ - An exception that exposes how long it attempted to backoff - """ + """An exception that exposes how long it attempted to backoff""" def __init__( self, - backoff: Union[int, float], + backoff: int | float, request: requests.PreparedRequest, - response: Optional[Union[requests.Response, Exception]], + response: requests.Response | Exception | None, error_message: str = "", ): - """ - :param backoff: how long to backoff in seconds + """:param backoff: how long to backoff in seconds :param request: the request that triggered this backoff exception :param response: the response that triggered the backoff exception """ diff --git a/airbyte_cdk/sources/streams/http/http.py b/airbyte_cdk/sources/streams/http/http.py index f9731517..c3498706 100644 --- a/airbyte_cdk/sources/streams/http/http.py +++ b/airbyte_cdk/sources/streams/http/http.py @@ -1,14 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Mapping, MutableMapping from datetime import timedelta -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any from urllib.parse import urljoin import requests +from deprecated import deprecated +from requests.auth import AuthBase + from airbyte_cdk.models import AirbyteMessage, FailureType, SyncMode from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.message.repository import InMemoryMessageRepository @@ -33,26 +38,21 @@ from airbyte_cdk.sources.streams.http.http_client import HttpClient from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.sources.utils.types import JsonType -from deprecated import deprecated -from requests.auth import AuthBase + # list of all possible HTTP methods which can be used for sending of request bodies BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") class HttpStream(Stream, CheckpointMixin, ABC): - """ - Base abstract class for an Airbyte Stream using the HTTP protocol. Basic building block for users building an Airbyte source for a HTTP API. - """ + """Base abstract class for an Airbyte Stream using the HTTP protocol. Basic building block for users building an Airbyte source for a HTTP API.""" source_defined_cursor = True # Most HTTP streams use a source defined cursor (i.e: the user can't configure it like on a SQL table) - page_size: Optional[int] = ( + page_size: int | None = ( None # Use this variable to define page size for API http requests with pagination support ) - def __init__( - self, authenticator: Optional[AuthBase] = None, api_budget: Optional[APIBudget] = None - ): + def __init__(self, authenticator: AuthBase | None = None, api_budget: APIBudget | None = None): self._exit_on_rate_limit: bool = False self._http_client = HttpClient( name=self.name, @@ -79,9 +79,7 @@ def __init__( @property def exit_on_rate_limit(self) -> bool: - """ - :return: False if the stream will retry endlessly when rate limited - """ + """:return: False if the stream will retry endlessly when rate limited""" return self._exit_on_rate_limit @exit_on_rate_limit.setter @@ -90,16 +88,14 @@ def exit_on_rate_limit(self, value: bool) -> None: @property def cache_filename(self) -> str: - """ - Override if needed. Return the name of cache file + """Override if needed. Return the name of cache file Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. """ return f"{self.name}.sqlite" @property def use_cache(self) -> bool: - """ - Override if needed. If True, all records will be cached. + """Override if needed. If True, all records will be cached. Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. """ return False @@ -107,15 +103,11 @@ def use_cache(self) -> bool: @property @abstractmethod def url_base(self) -> str: - """ - :return: URL base for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "https://myapi.com/v1/" - """ + """:return: URL base for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "https://myapi.com/v1/" """ @property def http_method(self) -> str: - """ - Override if needed. See get_request_data/get_request_json if using POST/PUT/PATCH. - """ + """Override if needed. 
See get_request_data/get_request_json if using POST/PUT/PATCH.""" return "GET" @property @@ -124,9 +116,7 @@ def http_method(self) -> str: reason="You should set error_handler explicitly in HttpStream.get_error_handler() instead.", ) def raise_on_http_errors(self) -> bool: - """ - Override if needed. If set to False, allows opting-out of raising HTTP code exception. - """ + """Override if needed. If set to False, allows opting-out of raising HTTP code exception.""" return True @property @@ -134,10 +124,8 @@ def raise_on_http_errors(self) -> bool: version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", ) - def max_retries(self) -> Union[int, None]: - """ - Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit. - """ + def max_retries(self) -> int | None: + """Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit.""" return 5 @property @@ -145,10 +133,8 @@ def max_retries(self) -> Union[int, None]: version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", ) - def max_time(self) -> Union[int, None]: - """ - Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit. - """ + def max_time(self) -> int | None: + """Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit.""" return 60 * 10 @property @@ -157,15 +143,12 @@ def max_time(self) -> Union[int, None]: reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", ) def retry_factor(self) -> float: - """ - Override if needed. Specifies factor for backoff policy. - """ + """Override if needed. Specifies factor for backoff policy.""" return 5 @abstractmethod - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: - """ - Override this method to define a pagination strategy. + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: + """Override this method to define a pagination strategy. The value returned from this method is passed to most other methods in this class. Use it to form a request e.g: set headers or query params. 
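The next_page_token docstring above is the core of HttpStream pagination; a minimal sketch of the pattern follows. It is not part of this patch, and the api.example.com endpoint, its "next" cursor field, and the Customers class are invented.

from __future__ import annotations

from collections.abc import Iterable, Mapping, MutableMapping
from typing import Any

import requests

from airbyte_cdk.sources.streams.http import HttpStream


class Customers(HttpStream):
    # Made-up endpoint: GET https://api.example.com/v1/customers?cursor=...
    url_base = "https://api.example.com/v1/"
    primary_key = "id"

    def path(self, **kwargs: Any) -> str:
        return "customers"

    def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None:
        # The returned mapping is handed back to request_params() for the next request.
        next_cursor = response.json().get("next")
        return {"cursor": next_cursor} if next_cursor else None

    def request_params(
        self,
        stream_state: Mapping[str, Any] | None,
        stream_slice: Mapping[str, Any] | None = None,
        next_page_token: Mapping[str, Any] | None = None,
    ) -> MutableMapping[str, Any]:
        return dict(next_page_token or {})

    def parse_response(
        self, response: requests.Response, **kwargs: Any
    ) -> Iterable[Mapping[str, Any]]:
        yield from response.json().get("data", [])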
@@ -176,22 +159,19 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: - """ - Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" - """ + """Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" """ def request_params( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: - """ - Override this method to define the query parameters that should be set on an outgoing HTTP request given the inputs. + """Override this method to define the query parameters that should be set on an outgoing HTTP request given the inputs. E.g: you might want to define query parameters for paging if next_page_token is not None. """ @@ -199,23 +179,20 @@ def request_params( def request_headers( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Override to return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method. - """ + """Override to return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.""" return {} def request_body_data( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Union[Mapping[str, Any], str]]: - """ - Override when creating POST/PUT/PATCH requests to populate the body of the request with a non-JSON payload. + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str | None: + """Override when creating POST/PUT/PATCH requests to populate the body of the request with a non-JSON payload. If returns a ready text that it will be sent as is. If returns a dict that it will be converted to a urlencoded form. @@ -227,12 +204,11 @@ def request_body_data( def request_body_json( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: - """ - Override when creating POST/PUT/PATCH requests to populate the body of the request with a JSON payload. + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: + """Override when creating POST/PUT/PATCH requests to populate the body of the request with a JSON payload. 
At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden. """ @@ -240,12 +216,11 @@ def request_body_json( def request_kwargs( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: - """ - Override to return a mapping of keyword arguments to be used when creating the HTTP request. + """Override to return a mapping of keyword arguments to be used when creating the HTTP request. Any option listed in https://docs.python-requests.org/en/latest/api/#requests.adapters.BaseAdapter.send for can be returned from this method. Note that these options do not conflict with request-level options such as headers, request params, etc.. """ @@ -257,11 +232,10 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: - """ - Parses the raw response object into a list of records. + """Parses the raw response object into a list of records. By default, this returns an iterable containing the input. Override to parse differently. :param response: :param stream_state: @@ -270,9 +244,8 @@ def parse_response( :return: An iterable containing the parsed response """ - def get_backoff_strategy(self) -> Optional[Union[BackoffStrategy, List[BackoffStrategy]]]: - """ - Used to initialize Adapter to avoid breaking changes. + def get_backoff_strategy(self) -> BackoffStrategy | list[BackoffStrategy] | None: + """Used to initialize Adapter to avoid breaking changes. If Stream has a `backoff_time` method implementation, we know this stream uses old (pre-HTTPClient) backoff handlers and thus an adapter is needed. Override to provide custom BackoffStrategy @@ -280,12 +253,10 @@ def get_backoff_strategy(self) -> Optional[Union[BackoffStrategy, List[BackoffSt """ if hasattr(self, "backoff_time"): return HttpStreamAdapterBackoffStrategy(self) - else: - return None + return None - def get_error_handler(self) -> Optional[ErrorHandler]: - """ - Used to initialize Adapter to avoid breaking changes. + def get_error_handler(self) -> ErrorHandler | None: + """Used to initialize Adapter to avoid breaking changes. If Stream has a `should_retry` method implementation, we know this stream uses old (pre-HTTPClient) error handlers and thus an adapter is needed. Override to provide custom ErrorHandler @@ -299,17 +270,15 @@ def get_error_handler(self) -> Optional[ErrorHandler]: max_time=timedelta(seconds=self.max_time or 0), ) return error_handler - else: - return None + return None @classmethod def _join_url(cls, url_base: str, path: str) -> str: return urljoin(url_base, path) @classmethod - def parse_response_error_message(cls, response: requests.Response) -> Optional[str]: - """ - Parses the raw response object from a failed request into a user-friendly error message. + def parse_response_error_message(cls, response: requests.Response) -> str | None: + """Parses the raw response object from a failed request into a user-friendly error message. By default, this method tries to grab the error message from JSON responses by following common API patterns. 
Override to parse differently. :param response: @@ -317,13 +286,13 @@ def parse_response_error_message(cls, response: requests.Response) -> Optional[s """ # default logic to grab error from common fields - def _try_get_error(value: Optional[JsonType]) -> Optional[str]: + def _try_get_error(value: JsonType | None) -> str | None: if isinstance(value, str): return value - elif isinstance(value, list): + if isinstance(value, list): errors_in_value = [_try_get_error(v) for v in value] return ", ".join(v for v in errors_in_value if v is not None) - elif isinstance(value, dict): + if isinstance(value, dict): new_value = ( value.get("message") or value.get("messages") @@ -342,9 +311,8 @@ def _try_get_error(value: Optional[JsonType]) -> Optional[str]: except requests.exceptions.JSONDecodeError: return None - def get_error_display_message(self, exception: BaseException) -> Optional[str]: - """ - Retrieves the user-friendly display message that corresponds to an exception. + def get_error_display_message(self, exception: BaseException) -> str | None: + """Retrieves the user-friendly display message that corresponds to an exception. This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. The default implementation of this method only handles HTTPErrors by passing the response to self.parse_response_error_message(). @@ -360,9 +328,9 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]: def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: # A cursor_field indicates this is an incremental stream which offers better checkpointing than RFR enabled via the cursor if self.cursor_field or not isinstance(self.get_cursor(), ResumableFullRefreshCursor): @@ -396,7 +364,7 @@ def state(self, value: MutableMapping[str, Any]) -> None: cursor.set_initial_state(value) self._state = value - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: # I don't love that this is semi-stateful but not sure what else to do. We don't know exactly what type of cursor to # instantiate when creating the class. We can make a few assumptions like if there is a cursor_field which implies # incremental, but we don't know until runtime if this is a substream. 
Ideally, a stream should explicitly define @@ -405,8 +373,7 @@ def get_cursor(self) -> Optional[Cursor]: if self.has_multiple_slices and isinstance(self.cursor, ResumableFullRefreshCursor): self.cursor = SubstreamResumableFullRefreshCursor() return self.cursor - else: - return self.cursor + return self.cursor def _read_pages( self, @@ -415,12 +382,12 @@ def _read_pages( requests.PreparedRequest, requests.Response, Mapping[str, Any], - Optional[Mapping[str, Any]], + Mapping[str, Any] | None, ], Iterable[StreamData], ], - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice) @@ -451,12 +418,12 @@ def _read_single_page( requests.PreparedRequest, requests.Response, Mapping[str, Any], - Optional[Mapping[str, Any]], + Mapping[str, Any] | None, ], Iterable[StreamData], ], - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: partition, cursor_slice, remaining_slice = self._extract_slice_fields( stream_slice=stream_slice @@ -480,7 +447,7 @@ def _read_single_page( @staticmethod def _extract_slice_fields( - stream_slice: Optional[Mapping[str, Any]], + stream_slice: Mapping[str, Any] | None, ) -> tuple[Mapping[str, Any], Mapping[str, Any], Mapping[str, Any]]: if not stream_slice: return {}, {}, {} @@ -504,10 +471,10 @@ def _extract_slice_fields( def _fetch_next_page( self, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Tuple[requests.PreparedRequest, requests.Response]: + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> tuple[requests.PreparedRequest, requests.Response]: request, response = self._http_client.send_request( http_method=self.http_method, url=self._join_url( @@ -550,19 +517,14 @@ def _fetch_next_page( return request, response - def get_log_formatter(self) -> Optional[Callable[[requests.Response], Any]]: - """ - - :return Optional[Callable[[requests.Response], Any]]: Function that will be used in logging inside HttpClient - """ + def get_log_formatter(self) -> Callable[[requests.Response], Any] | None: + """:return Optional[Callable[[requests.Response], Any]]: Function that will be used in logging inside HttpClient""" return None class HttpSubStream(HttpStream, ABC): def __init__(self, parent: HttpStream, **kwargs: Any): - """ - :param parent: should be the instance of HttpStream class - """ + """:param parent: should be the instance of HttpStream class""" super().__init__(**kwargs) self.parent = parent self.has_multiple_slices = ( @@ -584,9 +546,9 @@ def __init__(self, parent: HttpStream, **kwargs: Any): def stream_slices( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: # read_stateless() assumes the parent is not concurrent. 
This is currently okay since the concurrent CDK does # not support either substreams or RFR, but something that needs to be considered once we do for parent_record in self.parent.read_only_records(stream_state): @@ -611,10 +573,10 @@ def __init__(self, stream: HttpStream): def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: - return self.stream.backoff_time(response_or_exception) # type: ignore # noqa # HttpStream.backoff_time has been deprecated + ) -> float | None: + return self.stream.backoff_time(response_or_exception) # type: ignore # HttpStream.backoff_time has been deprecated @deprecated( @@ -627,12 +589,12 @@ def __init__(self, stream: HttpStream, **kwargs): # type: ignore # noqa super().__init__(**kwargs) def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] = None + self, response_or_exception: requests.Response | Exception | None = None ) -> ErrorResolution: if isinstance(response_or_exception, Exception): return super().interpret_response(response_or_exception) - elif isinstance(response_or_exception, requests.Response): - should_retry = self.stream.should_retry(response_or_exception) # type: ignore # noqa + if isinstance(response_or_exception, requests.Response): + should_retry = self.stream.should_retry(response_or_exception) # type: ignore if should_retry: if response_or_exception.status_code == 429: return ErrorResolution( @@ -645,29 +607,26 @@ def interpret_response( failure_type=FailureType.transient_error, error_message=f"Response status code: {response_or_exception.status_code}. Retrying...", # type: ignore[union-attr] ) - else: - if response_or_exception.ok: # type: ignore # noqa - return ErrorResolution( - response_action=ResponseAction.SUCCESS, - failure_type=None, - error_message=None, - ) - if self.stream.raise_on_http_errors: - return ErrorResolution( - response_action=ResponseAction.FAIL, - failure_type=FailureType.transient_error, - error_message=f"Response status code: {response_or_exception.status_code}. Unexpected error. Failed.", # type: ignore[union-attr] - ) - else: - return ErrorResolution( - response_action=ResponseAction.IGNORE, - failure_type=FailureType.transient_error, - error_message=f"Response status code: {response_or_exception.status_code}. Ignoring...", # type: ignore[union-attr] - ) - else: - self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + if response_or_exception.ok: # type: ignore + return ErrorResolution( + response_action=ResponseAction.SUCCESS, + failure_type=None, + error_message=None, + ) + if self.stream.raise_on_http_errors: + return ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.transient_error, + error_message=f"Response status code: {response_or_exception.status_code}. Unexpected error. Failed.", # type: ignore[union-attr] + ) return ErrorResolution( - response_action=ResponseAction.FAIL, - failure_type=FailureType.system_error, - error_message=f"Received unexpected response type: {type(response_or_exception)}", + response_action=ResponseAction.IGNORE, + failure_type=FailureType.transient_error, + error_message=f"Response status code: {response_or_exception.status_code}. 
Ignoring...", # type: ignore[union-attr] ) + self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + return ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message=f"Received unexpected response type: {type(response_or_exception)}", + ) diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py index 704b715c..89375196 100644 --- a/airbyte_cdk/sources/streams/http/http_client.py +++ b/airbyte_cdk/sources/streams/http/http_client.py @@ -1,16 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import os import urllib +from collections.abc import Callable, Mapping from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any import orjson import requests import requests_cache +from requests.auth import AuthBase + from airbyte_cdk.models import ( AirbyteMessageSerializer, AirbyteStreamStatus, @@ -48,14 +52,13 @@ as_airbyte_message as stream_status_as_airbyte_message, ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from requests.auth import AuthBase + BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException): - """ - Before the migration to the HttpClient in low-code, the exception raised was + """Before the migration to the HttpClient in low-code, the exception raised was [ReadException](https://github.com/airbytehq/airbyte/blob/8fdd9818ec16e653ba3dd2b167a74b7c07459861/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/http_requester.py#L566). This has been moved to a AirbyteTracedException. The printing on this is questionable (AirbyteTracedException string representation shows the internal_message and not the message). 
We have already discussed moving the AirbyteTracedException string representation to @@ -65,7 +68,7 @@ class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException): def __str__(self) -> str: if self.message: return self.message - elif self.internal_message: + if self.internal_message: return self.internal_message return "" @@ -78,15 +81,15 @@ def __init__( self, name: str, logger: logging.Logger, - error_handler: Optional[ErrorHandler] = None, - api_budget: Optional[APIBudget] = None, - session: Optional[Union[requests.Session, requests_cache.CachedSession]] = None, - authenticator: Optional[AuthBase] = None, + error_handler: ErrorHandler | None = None, + api_budget: APIBudget | None = None, + session: requests.Session | requests_cache.CachedSession | None = None, + authenticator: AuthBase | None = None, use_cache: bool = False, - backoff_strategy: Optional[Union[BackoffStrategy, List[BackoffStrategy]]] = None, - error_message_parser: Optional[ErrorMessageParser] = None, + backoff_strategy: BackoffStrategy | list[BackoffStrategy] | None = None, + error_message_parser: ErrorMessageParser | None = None, disable_retries: bool = False, - message_repository: Optional[MessageRepository] = None, + message_repository: MessageRepository | None = None, ): self._name = name self._api_budget: APIBudget = api_budget or APIBudget(policies=[]) @@ -113,21 +116,19 @@ def __init__( else: self._backoff_strategies = [DefaultBackoffStrategy()] self._error_message_parser = error_message_parser or JsonErrorMessageParser() - self._request_attempt_count: Dict[requests.PreparedRequest, int] = {} + self._request_attempt_count: dict[requests.PreparedRequest, int] = {} self._disable_retries = disable_retries self._message_repository = message_repository @property def cache_filename(self) -> str: - """ - Override if needed. Return the name of cache file + """Override if needed. Return the name of cache file Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. """ return f"{self._name}.sqlite" def _request_session(self) -> requests.Session: - """ - Session factory based on use_cache property and call rate limits (api_budget parameter) + """Session factory based on use_cache property and call rate limits (api_budget parameter) :return: instance of request-based session """ if self._use_cache: @@ -141,21 +142,15 @@ def _request_session(self) -> requests.Session: return CachedLimiterSession( sqlite_path, backend="sqlite", api_budget=self._api_budget, match_headers=True ) # type: ignore # there are no typeshed stubs for requests_cache - else: - return LimiterSession(api_budget=self._api_budget) + return LimiterSession(api_budget=self._api_budget) def clear_cache(self) -> None: - """ - Clear cached requests for current session, can be called any time - """ + """Clear cached requests for current session, can be called any time""" if isinstance(self._session, requests_cache.CachedSession): self._session.cache.clear() # type: ignore # cache.clear is not typed - def _dedupe_query_params( - self, url: str, params: Optional[Mapping[str, str]] - ) -> Mapping[str, str]: - """ - Remove query parameters from params mapping if they are already encoded in the URL. + def _dedupe_query_params(self, url: str, params: Mapping[str, str] | None) -> Mapping[str, str]: + """Remove query parameters from params mapping if they are already encoded in the URL. 
:param url: URL with :param params: :return: @@ -166,7 +161,7 @@ def _dedupe_query_params( query_dict = {k: v[0] for k, v in urllib.parse.parse_qs(query_string).items()} duplicate_keys_with_same_value = { - k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k]) + k for k in query_dict if str(params.get(k)) == str(query_dict[k]) } return {k: v for k, v in params.items() if k not in duplicate_keys_with_same_value} @@ -175,10 +170,10 @@ def _create_prepared_request( http_method: str, url: str, dedupe_query_params: bool = False, - headers: Optional[Mapping[str, str]] = None, - params: Optional[Mapping[str, str]] = None, - json: Optional[Mapping[str, Any]] = None, - data: Optional[Union[str, Mapping[str, Any]]] = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, str] | None = None, + json: Mapping[str, Any] | None = None, + data: str | Mapping[str, Any] | None = None, ) -> requests.PreparedRequest: if dedupe_query_params: query_params = self._dedupe_query_params(url, params) @@ -190,7 +185,7 @@ def _create_prepared_request( raise RequestBodyException( "At the same time only one of the 'request_body_data' and 'request_body_json' functions can return data" ) - elif json: + if json: args["json"] = json elif data: args["data"] = data @@ -202,9 +197,7 @@ def _create_prepared_request( @property def _max_retries(self) -> int: - """ - Determines the max retries based on the provided error handler. - """ + """Determines the max retries based on the provided error handler.""" max_retries = None if self._disable_retries: max_retries = 0 @@ -214,9 +207,7 @@ def _max_retries(self) -> int: @property def _max_time(self) -> int: - """ - Determines the max time based on the provided error handler. - """ + """Determines the max time based on the provided error handler.""" return ( self._error_handler.max_time if self._error_handler.max_time is not None @@ -227,11 +218,10 @@ def _send_with_retry( self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any], - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - exit_on_rate_limit: Optional[bool] = False, + log_formatter: Callable[[requests.Response], Any] | None = None, + exit_on_rate_limit: bool | None = False, ) -> requests.Response: - """ - Sends a request with retry logic. + """Sends a request with retry logic. Args: request (requests.PreparedRequest): The prepared HTTP request to send. @@ -240,7 +230,6 @@ def _send_with_retry( Returns: requests.Response: The HTTP response received from the server after retries. 
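Taken together, these pieces are normally driven through send_request (shown further down). A minimal sketch of using HttpClient on its own, with a placeholder URL:

import logging
from airbyte_cdk.sources.streams.http.http_client import HttpClient

http_client = HttpClient(name="users", logger=logging.getLogger("airbyte"))
request, response = http_client.send_request(
    http_method="GET",
    url="https://api.example.com/v1/users",  # placeholder endpoint
    request_kwargs={},
)
print(response.status_code)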
""" - max_retries = self._max_retries max_tries = max(0, max_retries) + 1 max_time = self._max_time @@ -266,8 +255,8 @@ def _send( self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any], - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - exit_on_rate_limit: Optional[bool] = False, + log_formatter: Callable[[requests.Response], Any] | None = None, + exit_on_rate_limit: bool | None = False, ) -> requests.Response: if request not in self._request_attempt_count: self._request_attempt_count[request] = 1 @@ -281,8 +270,8 @@ def _send( extra={"headers": request.headers, "url": request.url, "request_body": request.body}, ) - response: Optional[requests.Response] = None - exc: Optional[requests.RequestException] = None + response: requests.Response | None = None + exc: requests.RequestException | None = None try: response = self._session.send(request, **request_kwargs) @@ -335,11 +324,11 @@ def _send( def _handle_error_resolution( self, - response: Optional[requests.Response], - exc: Optional[requests.RequestException], + response: requests.Response | None, + exc: requests.RequestException | None, request: requests.PreparedRequest, error_resolution: ErrorResolution, - exit_on_rate_limit: Optional[bool] = False, + exit_on_rate_limit: bool | None = False, ) -> None: # Emit stream status RUNNING with the reason RATE_LIMITED to log that the rate limit has been reached if error_resolution.response_action == ResponseAction.RATE_LIMITED: @@ -373,7 +362,7 @@ def _handle_error_resolution( failure_type=error_resolution.failure_type, ) - elif error_resolution.response_action == ResponseAction.IGNORE: + if error_resolution.response_action == ResponseAction.IGNORE: if response is not None: log_message = f"Ignoring response for '{request.method}' request to '{request.url}' with response code '{response.status_code}'" else: @@ -413,7 +402,7 @@ def _handle_error_resolution( error_message=error_message, ) - elif retry_endlessly: + if retry_endlessly: raise RateLimitBackoffException( request=request, response=response or exc, error_message=error_message ) @@ -440,18 +429,15 @@ def send_request( http_method: str, url: str, request_kwargs: Mapping[str, Any], - headers: Optional[Mapping[str, str]] = None, - params: Optional[Mapping[str, str]] = None, - json: Optional[Mapping[str, Any]] = None, - data: Optional[Union[str, Mapping[str, Any]]] = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, str] | None = None, + json: Mapping[str, Any] | None = None, + data: str | Mapping[str, Any] | None = None, dedupe_query_params: bool = False, - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - exit_on_rate_limit: Optional[bool] = False, - ) -> Tuple[requests.PreparedRequest, requests.Response]: - """ - Prepares and sends request and return request and response objects. - """ - + log_formatter: Callable[[requests.Response], Any] | None = None, + exit_on_rate_limit: bool | None = False, + ) -> tuple[requests.PreparedRequest, requests.Response]: + """Prepares and sends request and return request and response objects.""" request: requests.PreparedRequest = self._create_prepared_request( http_method=http_method, url=url, diff --git a/airbyte_cdk/sources/streams/http/rate_limiting.py b/airbyte_cdk/sources/streams/http/rate_limiting.py index 926a7ad5..e34555ba 100644 --- a/airbyte_cdk/sources/streams/http/rate_limiting.py +++ b/airbyte_cdk/sources/streams/http/rate_limiting.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging import sys import time -from typing import Any, Callable, Mapping, Optional +from collections.abc import Callable, Mapping +from typing import Any import backoff from requests import PreparedRequest, RequestException, Response, codes, exceptions @@ -16,6 +18,7 @@ UserDefinedBackoffException, ) + TRANSIENT_EXCEPTIONS = ( DefaultBackoffException, exceptions.ConnectTimeout, @@ -31,7 +34,7 @@ def default_backoff_handler( - max_tries: Optional[int], factor: float, max_time: Optional[int] = None, **kwargs: Any + max_tries: int | None, factor: float, max_time: int | None = None, **kwargs: Any ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() @@ -40,7 +43,7 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None: f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" ) logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) def should_give_up(exc: Exception) -> bool: @@ -72,7 +75,7 @@ def should_give_up(exc: Exception) -> bool: def http_client_default_backoff_handler( - max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any + max_tries: int | None, max_time: int | None = None, **kwargs: Any ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() @@ -81,7 +84,7 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None: f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" ) logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) def should_give_up(exc: Exception) -> bool: @@ -101,7 +104,7 @@ def should_give_up(exc: Exception) -> bool: def user_defined_backoff_handler( - max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any + max_tries: int | None, max_time: int | None = None, **kwargs: Any ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() @@ -146,7 +149,7 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None: f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" ) logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." 
) return backoff.on_exception( # type: ignore # Decorator function returns a function with a different signature than the input function, so mypy can't infer the type of the returned function diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 7942aa36..37dc0266 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -1,31 +1,33 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import abstractmethod +from collections.abc import Mapping, MutableMapping from json import JSONDecodeError -from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any import backoff import pendulum import requests +from requests.auth import AuthBase + +from ..exceptions import DefaultBackoffException from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.http_logger import format_http_message from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository from airbyte_cdk.utils import AirbyteTracedException from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets -from requests.auth import AuthBase -from ..exceptions import DefaultBackoffException logger = logging.getLogger("airbyte") _NOOP_MESSAGE_REPOSITORY = NoopMessageRepository() class AbstractOauth2Authenticator(AuthBase): - """ - Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator + """Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator is designed to generically perform the refresh flow without regard to how config fields are get/set by delegating that behavior to the classes implementing the interface. """ @@ -34,12 +36,11 @@ class AbstractOauth2Authenticator(AuthBase): def __init__( self, - refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_status_codes: tuple[int, ...] = (), refresh_token_error_key: str = "", - refresh_token_error_values: Tuple[str, ...] = (), + refresh_token_error_values: tuple[str, ...] = (), ) -> None: - """ - If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, + """If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then http errors with such params will be wrapped in AirbyteTracedException. 
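For context on the rate_limiting.py handlers above: they are thin wrappers around the backoff library. A stripped-down sketch of the same pattern follows; the handler name and the retried exception are illustrative, not the module's API.

import backoff
import requests

def simple_backoff_handler(max_tries: int, factor: float):
    # Retry the decorated callable on connection errors with exponential backoff,
    # mirroring the shape of default_backoff_handler above.
    return backoff.on_exception(
        backoff.expo,
        requests.exceptions.ConnectionError,
        max_tries=max_tries,
        factor=factor,
        jitter=None,
    )

@simple_backoff_handler(max_tries=3, factor=5)
def fetch(session: requests.Session, url: str) -> requests.Response:
    return session.get(url)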
""" self._refresh_token_error_status_codes = refresh_token_error_status_codes @@ -69,8 +70,7 @@ def token_has_expired(self) -> bool: return pendulum.now() > self.get_token_expiry_date() # type: ignore # this is always a bool despite what mypy thinks def build_refresh_request_body(self) -> Mapping[str, Any]: - """ - Returns the request body to set on the refresh request + """Returns the request body to set on the refresh request Override to define additional parameters """ @@ -135,10 +135,9 @@ def _get_refresh_access_token_response(self) -> Any: add_to_secrets(access_key) self._log_response(response) return response_json - else: - # log the response even if the request failed for troubleshooting purposes - self._log_response(response) - response.raise_for_status() + # log the response even if the request failed for troubleshooting purposes + self._log_response(response) + response.raise_for_status() except requests.exceptions.RequestException as e: if e.response is not None: if e.response.status_code == 429 or e.response.status_code >= 500: @@ -152,9 +151,8 @@ def _get_refresh_access_token_response(self) -> Any: except Exception as e: raise Exception(f"Error while refreshing access token: {e}") from e - def refresh_access_token(self) -> Tuple[str, Union[str, int]]: - """ - Returns the refresh token and its expiration datetime + def refresh_access_token(self) -> tuple[str, str | int]: + """Returns the refresh token and its expiration datetime :return: a tuple of (access_token, token_lifespan) """ @@ -164,36 +162,27 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: self.get_expires_in_name() ] - def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateTime: - """ - Return the expiration datetime of the refresh token + def _parse_token_expiration_date(self, value: str | int) -> pendulum.DateTime: + """Return the expiration datetime of the refresh token :return: expiration datetime """ - if self.token_expiry_is_time_of_expiration: if not self.token_expiry_date_format: raise ValueError( f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required." ) return pendulum.from_format(str(value), self.token_expiry_date_format) - else: - return pendulum.now().add(seconds=int(float(value))) + return pendulum.now().add(seconds=int(float(value))) @property def token_expiry_is_time_of_expiration(self) -> bool: - """ - Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. 
- """ - + """Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.""" return False @property - def token_expiry_date_format(self) -> Optional[str]: - """ - Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires - """ - + def token_expiry_date_format(self) -> str | None: + """Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires""" return None @abstractmethod @@ -209,11 +198,11 @@ def get_client_secret(self) -> str: """The client secret to authenticate""" @abstractmethod - def get_refresh_token(self) -> Optional[str]: + def get_refresh_token(self) -> str | None: """The token used to refresh the access token when it expires""" @abstractmethod - def get_scopes(self) -> List[str]: + def get_scopes(self) -> list[str]: """List of requested scopes""" @abstractmethod @@ -221,7 +210,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime: """Expiration date of the access token""" @abstractmethod - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: str | int) -> None: """Setter for access token expiration date""" @abstractmethod @@ -251,10 +240,8 @@ def access_token(self, value: str) -> str: """Setter for the access token""" @property - def _message_repository(self) -> Optional[MessageRepository]: - """ - The implementation can define a message_repository if it wants debugging logs for HTTP requests - """ + def _message_repository(self) -> MessageRepository | None: + """The implementation can define a message_repository if it wants debugging logs for HTTP requests""" return _NOOP_MESSAGE_REPOSITORY def _log_response(self, response: requests.Response) -> None: diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py index db59600d..1685439b 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from abc import abstractmethod -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from requests.auth import AuthBase diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 4ae84048..c107443c 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence +from typing import Any import dpath import pendulum + from airbyte_cdk.config_observation import ( create_connector_config_control_message, emit_configuration_as_airbyte_control_message, @@ -17,8 +20,7 @@ class Oauth2Authenticator(AbstractOauth2Authenticator): - """ - Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials. + """Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials. The generated access token is attached to each request via the Authorization header. 
If a connector_config is provided any mutation of it's value in the scope of this class will emit AirbyteControlConnectorConfigMessage. """ @@ -29,7 +31,7 @@ def __init__( client_id: str, client_secret: str, refresh_token: str, - scopes: List[str] = None, + scopes: list[str] = None, token_expiry_date: pendulum.DateTime = None, token_expiry_date_format: str = None, access_token_name: str = "access_token", @@ -37,9 +39,9 @@ def __init__( refresh_request_body: Mapping[str, Any] = None, grant_type: str = "refresh_token", token_expiry_is_time_of_expiration: bool = False, - refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_status_codes: tuple[int, ...] = (), refresh_token_error_key: str = "", - refresh_token_error_values: Tuple[str, ...] = (), + refresh_token_error_values: tuple[str, ...] = (), ): self._token_refresh_endpoint = token_refresh_endpoint self._client_secret = client_secret @@ -89,7 +91,7 @@ def get_grant_type(self) -> str: def get_token_expiry_date(self) -> pendulum.DateTime: return self._token_expiry_date - def set_token_expiry_date(self, value: Union[str, int]): + def set_token_expiry_date(self, value: str | int): self._token_expiry_date = self._parse_token_expiration_date(value) @property @@ -97,7 +99,7 @@ def token_expiry_is_time_of_expiration(self) -> bool: return self._token_expiry_is_time_of_expiration @property - def token_expiry_date_format(self) -> Optional[str]: + def token_expiry_date_format(self) -> str | None: return self._token_expiry_date_format @property @@ -110,8 +112,7 @@ def access_token(self, value: str): class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator): - """ - Authenticator that should be used for API implementing single use refresh tokens: + """Authenticator that should be used for API implementing single use refresh tokens: when refreshing access token some API returns a new refresh token that needs to used in the next refresh flow. This authenticator updates the configuration with new refresh token by emitting Airbyte control message from an observed mutation. By default, this authenticator expects a connector config with a "credentials" field with the following nested fields: client_id, @@ -123,42 +124,41 @@ def __init__( self, connector_config: Mapping[str, Any], token_refresh_endpoint: str, - scopes: List[str] = None, + scopes: list[str] = None, access_token_name: str = "access_token", expires_in_name: str = "expires_in", refresh_token_name: str = "refresh_token", refresh_request_body: Mapping[str, Any] = None, grant_type: str = "refresh_token", - client_id: Optional[str] = None, - client_secret: Optional[str] = None, + client_id: str | None = None, + client_secret: str | None = None, access_token_config_path: Sequence[str] = ("credentials", "access_token"), refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"), token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"), - token_expiry_date_format: Optional[str] = None, + token_expiry_date_format: str | None = None, message_repository: MessageRepository = NoopMessageRepository(), token_expiry_is_time_of_expiration: bool = False, - refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_status_codes: tuple[int, ...] = (), refresh_token_error_key: str = "", - refresh_token_error_values: Tuple[str, ...] = (), + refresh_token_error_values: tuple[str, ...] 
= (), ): - """ - Args: - connector_config (Mapping[str, Any]): The full connector configuration - token_refresh_endpoint (str): Full URL to the token refresh endpoint - scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None. - access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token". - expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in". - refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token". - refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None. - grant_type (str, optional): OAuth grant type. Defaults to "refresh_token". - client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object. - client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object. - access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token"). - refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token"). - token_expiry_date_config_path (Sequence[str]): Dpath to the token_expiry_date field in the connector configuration. Defaults to ("credentials", "token_expiry_date"). - token_expiry_date_format (Optional[str]): Date format of the token expiry date field (set by expires_in_name). If not specified the token expiry date is interpreted as number of seconds until expiration. - token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration - message_repository (MessageRepository): the message repository used to emit logs on HTTP requests and control message on config update + """Args: + connector_config (Mapping[str, Any]): The full connector configuration + token_refresh_endpoint (str): Full URL to the token refresh endpoint + scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None. + access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token". + expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in". + refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token". + refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None. + grant_type (str, optional): OAuth grant type. Defaults to "refresh_token". + client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object. + client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object. 
+ access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token"). + refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token"). + token_expiry_date_config_path (Sequence[str]): Dpath to the token_expiry_date field in the connector configuration. Defaults to ("credentials", "token_expiry_date"). + token_expiry_date_format (Optional[str]): Date format of the token expiry date field (set by expires_in_name). If not specified the token expiry date is interpreted as number of seconds until expiration. + token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration + message_repository (MessageRepository): the message repository used to emit logs on HTTP requests and control message on config update """ self._client_id = ( client_id @@ -239,8 +239,7 @@ def get_new_token_expiry_date( ) -> pendulum.DateTime: if token_expiry_date_format: return pendulum.from_format(access_token_expires_in, token_expiry_date_format) - else: - return pendulum.now("UTC").add(seconds=int(access_token_expires_in)) + return pendulum.now("UTC").add(seconds=int(access_token_expires_in)) def get_access_token(self) -> str: """Retrieve new access and refresh token if the access token has expired. @@ -269,7 +268,7 @@ def get_access_token(self) -> str: emit_configuration_as_airbyte_control_message(self._connector_config) return self.access_token - def refresh_access_token(self) -> Tuple[str, str, str]: + def refresh_access_token(self) -> tuple[str, str, str]: response_json = self._get_refresh_access_token_response() return ( response_json[self.get_access_token_name()], @@ -279,7 +278,5 @@ def refresh_access_token(self) -> Tuple[str, str, str]: @property def _message_repository(self) -> MessageRepository: - """ - Overriding AbstractOauth2Authenticator._message_repository to allow for HTTP request logs - """ + """Overriding AbstractOauth2Authenticator._message_repository to allow for HTTP request logs""" return self.__message_repository diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py index eec7fd0c..5b03e458 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py @@ -1,10 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import base64 from itertools import cycle -from typing import List from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import ( AbstractHeaderAuthenticator, @@ -12,8 +12,7 @@ class MultipleTokenAuthenticator(AbstractHeaderAuthenticator): - """ - Builds auth header, based on the list of tokens provided. + """Builds auth header, based on the list of tokens provided. Auth header is changed per each `get_auth_header` call, using each token in cycle. The token is attached to each request via the `auth_header` header. 
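A small usage sketch of the token cycling described here; the exact header shape is assumed from get_auth_header's documented behavior.

from airbyte_cdk.sources.streams.http.requests_native_auth.token import MultipleTokenAuthenticator

authenticator = MultipleTokenAuthenticator(tokens=["token_a", "token_b"])
authenticator.get_auth_header()  # {"Authorization": "Bearer token_a"}
authenticator.get_auth_header()  # {"Authorization": "Bearer token_b"} -- tokens are used in a cycle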
""" @@ -27,7 +26,7 @@ def token(self) -> str: return f"{self._auth_method} {next(self._tokens_iter)}" def __init__( - self, tokens: List[str], auth_method: str = "Bearer", auth_header: str = "Authorization" + self, tokens: list[str], auth_method: str = "Bearer", auth_header: str = "Authorization" ): self._auth_method = auth_method self._auth_header = auth_header @@ -36,8 +35,7 @@ def __init__( class TokenAuthenticator(AbstractHeaderAuthenticator): - """ - Builds auth header, based on the token provided. + """Builds auth header, based on the token provided. The token is attached to each request via the `auth_header` header. """ @@ -56,8 +54,7 @@ def __init__(self, token: str, auth_method: str = "Bearer", auth_header: str = " class BasicHttpAuthenticator(AbstractHeaderAuthenticator): - """ - Builds auth based off the basic authentication scheme as defined by RFC 7617, which transmits credentials as USER ID/password pairs, encoded using bas64 + """Builds auth based off the basic authentication scheme as defined by RFC 7617, which transmits credentials as USER ID/password pairs, encoded using bas64 https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication#basic_authentication_scheme """ @@ -76,7 +73,7 @@ def __init__( auth_method: str = "Basic", auth_header: str = "Authorization", ): - auth_string = f"{username}:{password}".encode("utf8") + auth_string = f"{username}:{password}".encode() b64_encoded = base64.b64encode(auth_string).decode("utf8") self._auth_header = auth_header self._auth_method = auth_method diff --git a/airbyte_cdk/sources/types.py b/airbyte_cdk/sources/types.py index eb13cd08..c72fd8e4 100644 --- a/airbyte_cdk/sources/types.py +++ b/airbyte_cdk/sources/types.py @@ -4,18 +4,20 @@ from __future__ import annotations -from typing import Any, ItemsView, Iterator, KeysView, List, Mapping, Optional, ValuesView +from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView +from typing import Any + # A FieldPointer designates a path to a field inside a mapping. For example, retrieving ["k1", "k1.2"] in the object {"k1" :{"k1.2": # "hello"}] returns "hello" -FieldPointer = List[str] +FieldPointer = list[str] Config = Mapping[str, Any] ConnectionDefinition = Mapping[str, Any] StreamState = Mapping[str, Any] class Record(Mapping[str, Any]): - def __init__(self, data: Mapping[str, Any], associated_slice: Optional[StreamSlice]): + def __init__(self, data: Mapping[str, Any], associated_slice: StreamSlice | None): self._data = data self._associated_slice = associated_slice @@ -24,7 +26,7 @@ def data(self) -> Mapping[str, Any]: return self._data @property - def associated_slice(self) -> Optional[StreamSlice]: + def associated_slice(self) -> StreamSlice | None: return self._associated_slice def __repr__(self) -> str: @@ -58,10 +60,9 @@ def __init__( *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any], - extra_fields: Optional[Mapping[str, Any]] = None, + extra_fields: Mapping[str, Any] | None = None, ) -> None: - """ - :param partition: The partition keys representing a unique partition in the stream. + """:param partition: The partition keys representing a unique partition in the stream. :param cursor_slice: The incremental cursor slice keys, such as dates or pagination tokens. :param extra_fields: Additional fields that should not be part of the partition but passed along, such as metadata from the parent stream. 
""" @@ -123,10 +124,10 @@ def items(self) -> ItemsView[str, Any]: def values(self) -> ValuesView[Any]: return self._stream_slice.values() - def get(self, key: str, default: Any = None) -> Optional[Any]: + def get(self, key: str, default: Any = None) -> Any | None: return self._stream_slice.get(key, default) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, dict): return self._stream_slice == other if isinstance(other, StreamSlice): @@ -134,7 +135,7 @@ def __eq__(self, other: Any) -> bool: return self._partition == other._partition and self._cursor_slice == other._cursor_slice return False - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __json_serializable__(self) -> Any: diff --git a/airbyte_cdk/sources/utils/casing.py b/airbyte_cdk/sources/utils/casing.py index 806e077a..79a273d9 100644 --- a/airbyte_cdk/sources/utils/casing.py +++ b/airbyte_cdk/sources/utils/casing.py @@ -1,7 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import re diff --git a/airbyte_cdk/sources/utils/record_helper.py b/airbyte_cdk/sources/utils/record_helper.py index e45601c2..4f64095a 100644 --- a/airbyte_cdk/sources/utils/record_helper.py +++ b/airbyte_cdk/sources/utils/record_helper.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import time +from collections.abc import Mapping from collections.abc import Mapping as ABCMapping -from typing import Any, Mapping, Optional +from typing import Any from airbyte_cdk.models import ( AirbyteLogMessage, @@ -21,7 +24,7 @@ def stream_data_to_airbyte_message( stream_name: str, data_or_message: StreamData, transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform), - schema: Optional[Mapping[str, Any]] = None, + schema: Mapping[str, Any] | None = None, is_file_transfer_message: bool = False, ) -> AirbyteMessage: if schema is None: diff --git a/airbyte_cdk/sources/utils/schema_helpers.py b/airbyte_cdk/sources/utils/schema_helpers.py index 5b1287ac..0a72cc5d 100644 --- a/airbyte_cdk/sources/utils/schema_helpers.py +++ b/airbyte_cdk/sources/utils/schema_helpers.py @@ -1,25 +1,26 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import importlib import json import os import pkgutil -from typing import Any, ClassVar, Dict, List, Mapping, MutableMapping, Optional, Tuple +from collections.abc import Mapping, MutableMapping +from typing import Any, ClassVar import jsonref -from airbyte_cdk.models import ConnectorSpecification, FailureType -from airbyte_cdk.utils.traced_exception import AirbyteTracedException from jsonschema import RefResolver, validate from jsonschema.exceptions import ValidationError from pydantic.v1 import BaseModel, Field +from airbyte_cdk.models import ConnectorSpecification, FailureType +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + class JsonFileLoader: - """ - Custom json file loader to resolve references to resources located in "shared" directory. + """Custom json file loader to resolve references to resources located in "shared" directory. 
We need this for compatability with existing schemas cause all of them have references pointing to shared_schema.json file instead of shared/shared_schema.json """ @@ -28,19 +29,17 @@ def __init__(self, uri_base: str, shared: str): self.shared = shared self.uri_base = uri_base - def __call__(self, uri: str) -> Dict[str, Any]: + def __call__(self, uri: str) -> dict[str, Any]: uri = uri.replace(self.uri_base, f"{self.uri_base}/{self.shared}/") with open(uri) as f: data = json.load(f) if isinstance(data, dict): return data - else: - raise ValueError(f"Expected to read a dictionary from {uri}. Got: {data}") + raise ValueError(f"Expected to read a dictionary from {uri}. Got: {data}") def resolve_ref_links(obj: Any) -> Any: - """ - Scan resolved schema and convert jsonref.JsonRef object to JSON serializable dict. + """Scan resolved schema and convert jsonref.JsonRef object to JSON serializable dict. :param obj - jsonschema object with ref field resolved. :return JSON serializable object with references without external dependencies. @@ -52,17 +51,15 @@ def resolve_ref_links(obj: Any) -> Any: if isinstance(obj, dict): obj.pop("definitions", None) return obj - else: - raise ValueError(f"Expected obj to be a dict. Got {obj}") - elif isinstance(obj, dict): + raise ValueError(f"Expected obj to be a dict. Got {obj}") + if isinstance(obj, dict): return {k: resolve_ref_links(v) for k, v in obj.items()} - elif isinstance(obj, list): + if isinstance(obj, list): return [resolve_ref_links(item) for item in obj] - else: - return obj + return obj -def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> None: +def _expand_refs(schema: Any, ref_resolver: RefResolver | None = None) -> None: """Internal function to iterate over schema and replace all occurrences of $ref with their definitions. Recursive. :param schema: schema that will be patched @@ -81,7 +78,7 @@ def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> Non else: for key, value in schema.items(): _expand_refs(value, ref_resolver=ref_resolver) - elif isinstance(schema, List): + elif isinstance(schema, list): for value in schema: _expand_refs(value, ref_resolver=ref_resolver) @@ -118,9 +115,7 @@ def __init__(self, package_name: str): self.package_name = package_name def get_schema(self, name: str) -> dict[str, Any]: - """ - This method retrieves a JSON schema from the schemas/ folder. - + """This method retrieves a JSON schema from the schemas/ folder. The expected file structure is to have all top-level schemas (corresponding to streams) in the "schemas/" folder, with any shared $refs living inside the "schemas/shared/" folder. For example: @@ -129,11 +124,10 @@ def get_schema(self, name: str) -> dict[str, Any]: schemas/.json # contains a $ref to shared_definition schemas/.json # contains a $ref to shared_definition """ - schema_filename = f"schemas/{name}.json" raw_file = pkgutil.get_data(self.package_name, schema_filename) if not raw_file: - raise IOError(f"Cannot find file {schema_filename}") + raise OSError(f"Cannot find file {schema_filename}") try: raw_schema = json.loads(raw_file) except ValueError as err: @@ -142,13 +136,11 @@ def get_schema(self, name: str) -> dict[str, Any]: return self._resolve_schema_references(raw_schema) def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, Any]: - """ - Resolve links to external references and move it to local "definitions" map. + """Resolve links to external references and move it to local "definitions" map. 
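Given that layout, a connector typically loads a stream schema like this; the package name is a placeholder.

from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader

loader = ResourceSchemaLoader(package_name="source_example")  # hypothetical connector package
users_schema = loader.get_schema("users")  # reads schemas/users.json and resolves shared $refs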
:param raw_schema jsonschema to lookup for external links. :return JSON serializable object with references without external dependencies. """ - package = importlib.import_module(self.package_name) if package.__file__: base = os.path.dirname(package.__file__) + "/" @@ -160,15 +152,13 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An resolved = resolve_ref_links(resolved) if isinstance(resolved, dict): return resolved - else: - raise ValueError(f"Expected resolved to be a dict. Got {resolved}") + raise ValueError(f"Expected resolved to be a dict. Got {resolved}") def check_config_against_spec_or_exit( config: Mapping[str, Any], spec: ConnectorSpecification ) -> None: - """ - Check config object against spec. In case of spec is invalid, throws + """Check config object against spec. In case of spec is invalid, throws an exception with validation error description. :param config - config loaded from file specified over command line @@ -196,8 +186,7 @@ def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: return super().dict(*args, **kwargs) # type: ignore[no-any-return] def is_limit_reached(self, records_counter: int) -> bool: - """ - Check if record count reached limit set by internal config. + """Check if record count reached limit set by internal config. :param records_counter - number of records already red :return True if limit reached, False otherwise """ @@ -207,9 +196,8 @@ def is_limit_reached(self, records_counter: int) -> bool: return False -def split_config(config: Mapping[str, Any]) -> Tuple[dict[str, Any], InternalConfig]: - """ - Break config map object into 2 instances: first is a dict with user defined +def split_config(config: Mapping[str, Any]) -> tuple[dict[str, Any], InternalConfig]: + """Break config map object into 2 instances: first is a dict with user defined configuration and second is internal config that contains private keys for acceptance test configuration. diff --git a/airbyte_cdk/sources/utils/slice_logger.py b/airbyte_cdk/sources/utils/slice_logger.py index ee802a7a..28fc6aa0 100644 --- a/airbyte_cdk/sources/utils/slice_logger.py +++ b/airbyte_cdk/sources/utils/slice_logger.py @@ -1,27 +1,27 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level from airbyte_cdk.models import Type as MessageType class SliceLogger(ABC): - """ - SliceLogger is an interface that allows us to log slices of data in a uniform way. + """SliceLogger is an interface that allows us to log slices of data in a uniform way. It is responsible for determining whether or not a slice should be logged and for creating the log message. """ SLICE_LOG_PREFIX = "slice:" - def create_slice_log_message(self, _slice: Optional[Mapping[str, Any]]) -> AirbyteMessage: - """ - Mapping is an interface that can be implemented in various ways. However, json.dumps will just do a `str()` if + def create_slice_log_message(self, _slice: Mapping[str, Any] | None) -> AirbyteMessage: + """Mapping is an interface that can be implemented in various ways. However, json.dumps will just do a `str()` if the slice is a class implementing Mapping. 
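A quick sketch of how split_config and InternalConfig above are meant to be used together; the config contents are placeholders.

from airbyte_cdk.sources.utils.schema_helpers import split_config

raw_config = {"api_key": "<secret>"}  # only user-facing fields in this example
config, internal_config = split_config(raw_config)
if internal_config.is_limit_reached(records_counter=1000):
    print("record limit reached, stopping the read")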
Therefore, we want to cast this as a dict before passing this to json.dump """ printable_slice = dict(_slice) if _slice else _slice @@ -35,18 +35,14 @@ def create_slice_log_message(self, _slice: Optional[Mapping[str, Any]]) -> Airby @abstractmethod def should_log_slice_message(self, logger: logging.Logger) -> bool: - """ - - :param logger: + """:param logger: :return: """ class DebugSliceLogger(SliceLogger): def should_log_slice_message(self, logger: logging.Logger) -> bool: - """ - - :param logger: + """:param logger: :return: """ return logger.isEnabledFor(logging.DEBUG) diff --git a/airbyte_cdk/sources/utils/transform.py b/airbyte_cdk/sources/utils/transform.py index ef52c5fd..d3d7e1d4 100644 --- a/airbyte_cdk/sources/utils/transform.py +++ b/airbyte_cdk/sources/utils/transform.py @@ -1,14 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Callable, Mapping from distutils.util import strtobool from enum import Flag, auto -from typing import Any, Callable, Dict, Mapping, Optional +from typing import Any from jsonschema import Draft7Validator, ValidationError, validators + json_to_python_simple = { "string": str, "number": float, @@ -16,18 +19,17 @@ "boolean": bool, "null": type(None), } -json_to_python = {**json_to_python_simple, **{"object": dict, "array": list}} +json_to_python = {**json_to_python_simple, "object": dict, "array": list} python_to_json = {v: k for k, v in json_to_python.items()} logger = logging.getLogger("airbyte") class TransformConfig(Flag): - """ - TypeTransformer class config. Configs can be combined using bitwise or operator e.g. - ``` - TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization - ``` + """TypeTransformer class config. Configs can be combined using bitwise or operator e.g. + ``` + TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization + ``` """ # No action taken, default behaviour. Cannot be combined with any other options. @@ -42,15 +44,12 @@ class TransformConfig(Flag): class TypeTransformer: - """ - Class for transforming object before output. - """ + """Class for transforming object before output.""" - _custom_normalizer: Optional[Callable[[Any, Dict[str, Any]], Any]] = None + _custom_normalizer: Callable[[Any, dict[str, Any]], Any] | None = None def __init__(self, config: TransformConfig): - """ - Initialize TypeTransformer instance. + """Initialize TypeTransformer instance. :param config Transform config that would be applied to object """ if TransformConfig.NoTransform in config and config != TransformConfig.NoTransform: @@ -67,10 +66,9 @@ def __init__(self, config: TransformConfig): ) def registerCustomTransform( - self, normalization_callback: Callable[[Any, Dict[str, Any]], Any] + self, normalization_callback: Callable[[Any, dict[str, Any]], Any] ) -> Callable: - """ - Register custom normalization callback. + """Register custom normalization callback. :param normalization_callback function to be used for value normalization. Takes original value and part type schema. Should return normalized value. See docs/connector-development/cdk-python/schemas.md @@ -84,9 +82,8 @@ def registerCustomTransform( self._custom_normalizer = normalization_callback return normalization_callback - def __normalize(self, original_item: Any, subschema: Dict[str, Any]) -> Any: - """ - Applies different transform function to object's field according to config. 
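A minimal sketch of the default normalization path described above:

from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
record = {"id": "123", "active": "true"}
schema = {
    "type": "object",
    "properties": {"id": {"type": "integer"}, "active": {"type": "boolean"}},
}
transformer.transform(record, schema)  # mutates the record in place
# record is now {"id": 123, "active": True}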
+ def __normalize(self, original_item: Any, subschema: dict[str, Any]) -> Any: + """Applies different transform function to object's field according to config. :param original_item original value of field. :param subschema part of the jsonschema containing field type/format data. :return Final field value. @@ -99,9 +96,8 @@ def __normalize(self, original_item: Any, subschema: Dict[str, Any]) -> Any: return original_item @staticmethod - def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any: - """ - Default transform function that is used when TransformConfig.DefaultSchemaNormalization flag set. + def default_convert(original_item: Any, subschema: dict[str, Any]) -> Any: + """Default transform function that is used when TransformConfig.DefaultSchemaNormalization flag set. :param original_item original value of field. :param subschema part of the jsonschema containing field type/format data. :return transformed field value. @@ -122,15 +118,15 @@ def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any: try: if target_type == "string": return str(original_item) - elif target_type == "number": + if target_type == "number": return float(original_item) - elif target_type == "integer": + if target_type == "integer": return int(original_item) - elif target_type == "boolean": + if target_type == "boolean": if isinstance(original_item, str): return strtobool(original_item) == 1 return bool(original_item) - elif target_type == "array": + if target_type == "array": item_types = set(subschema.get("items", {}).get("type", set())) if ( item_types.issubset(json_to_python_simple) @@ -142,17 +138,15 @@ def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any: return original_item def __get_normalizer(self, schema_key: str, original_validator: Callable): - """ - Traverse through object fields using native jsonschema validator and apply normalization function. + """Traverse through object fields using native jsonschema validator and apply normalization function. :param schema_key related json schema key that currently being validated/normalized. :original_validator: native jsonschema validator callback. """ def normalizator( - validator_instance: Callable, property_value: Any, instance: Any, schema: Dict[str, Any] + validator_instance: Callable, property_value: Any, instance: Any, schema: dict[str, Any] ): - """ - Jsonschema validator callable it uses for validating instance. We + """Jsonschema validator callable it uses for validating instance. We override default Draft7Validator to perform value transformation before validation take place. We do not take any action except logging warn if object does not conform to json schema, just using @@ -189,9 +183,8 @@ def resolve(subschema): return normalizator - def transform(self, record: Dict[str, Any], schema: Mapping[str, Any]): - """ - Normalize and validate according to config. + def transform(self, record: dict[str, Any], schema: Mapping[str, Any]): + """Normalize and validate according to config. :param record: record instance for normalization/transformation. All modification are done by modifying existent object. :param schema: object's jsonschema for normalization. """ @@ -208,4 +201,4 @@ def transform(self, record: Dict[str, Any], schema: Mapping[str, Any]): def get_error_message(self, e: ValidationError) -> str: instance_json_type = python_to_json[type(e.instance)] key_path = "." 
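To make the transformer API above concrete, a small usage sketch: combine the flags, register a custom callback, and normalize a record in place against its schema. Field names and values are illustrative only.

```python
# Sketch: normalize a record against its JSON schema with TypeTransformer.
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

transformer = TypeTransformer(
    TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization
)

@transformer.registerCustomTransform
def strip_strings(original_value, field_schema):
    # Custom callback: receives the raw value and the field's subschema.
    return original_value.strip() if isinstance(original_value, str) else original_value

record = {"id": "42 ", "price": "9.99"}
schema = {
    "type": "object",
    "properties": {"id": {"type": "integer"}, "price": {"type": "number"}},
}
transformer.transform(record, schema)
# record is mutated in place; expected result: {"id": 42, "price": 9.99}
```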
+ ".".join(map(str, e.path)) - return f"Failed to transform value {repr(e.instance)} of type '{instance_json_type}' to '{e.validator_value}', key path: '{key_path}'" + return f"Failed to transform value {e.instance!r} of type '{instance_json_type}' to '{e.validator_value}', key path: '{key_path}'" diff --git a/airbyte_cdk/sources/utils/types.py b/airbyte_cdk/sources/utils/types.py index 9dc5e253..707aedc5 100644 --- a/airbyte_cdk/sources/utils/types.py +++ b/airbyte_cdk/sources/utils/types.py @@ -1,7 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from typing import Union + JsonType = Union[dict[str, "JsonType"], list["JsonType"], str, int, float, bool, None] diff --git a/airbyte_cdk/sql/_util/hashing.py b/airbyte_cdk/sql/_util/hashing.py index 781305c4..a88798c7 100644 --- a/airbyte_cdk/sql/_util/hashing.py +++ b/airbyte_cdk/sql/_util/hashing.py @@ -6,6 +6,7 @@ import hashlib from collections.abc import Mapping + HASH_SEED = "Airbyte:" """Additional seed for randomizing one-way hashed strings.""" diff --git a/airbyte_cdk/sql/_util/name_normalizers.py b/airbyte_cdk/sql/_util/name_normalizers.py index 9311432d..21c21611 100644 --- a/airbyte_cdk/sql/_util/name_normalizers.py +++ b/airbyte_cdk/sql/_util/name_normalizers.py @@ -10,6 +10,7 @@ from airbyte_cdk.sql import exceptions as exc + if TYPE_CHECKING: from collections.abc import Iterable diff --git a/airbyte_cdk/sql/constants.py b/airbyte_cdk/sql/constants.py index 2f7de781..b499d31f 100644 --- a/airbyte_cdk/sql/constants.py +++ b/airbyte_cdk/sql/constants.py @@ -3,6 +3,7 @@ from __future__ import annotations + DEBUG_MODE = False # Set to True to enable additional debug logging. AB_EXTRACTED_AT_COLUMN = "_airbyte_extracted_at" diff --git a/airbyte_cdk/sql/exceptions.py b/airbyte_cdk/sql/exceptions.py index 963dc469..79748605 100644 --- a/airbyte_cdk/sql/exceptions.py +++ b/airbyte_cdk/sql/exceptions.py @@ -44,6 +44,7 @@ from textwrap import indent from typing import Any + NEW_ISSUE_URL = "https://github.com/airbytehq/airbyte/issues/new/choose" DOCS_URL_BASE = "https://https://docs.airbyte.com/" DOCS_URL = f"{DOCS_URL_BASE}/airbyte.html" diff --git a/airbyte_cdk/sql/secrets.py b/airbyte_cdk/sql/secrets.py index bff9e810..8df5b488 100644 --- a/airbyte_cdk/sql/secrets.py +++ b/airbyte_cdk/sql/secrets.py @@ -6,9 +6,11 @@ import json from typing import TYPE_CHECKING, Any -from airbyte_cdk.sql import exceptions as exc from pydantic_core import CoreSchema, core_schema +from airbyte_cdk.sql import exceptions as exc + + if TYPE_CHECKING: from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationInfo from pydantic.json_schema import JsonSchemaValue @@ -95,7 +97,7 @@ def validate( return cls(v) @classmethod - def __get_pydantic_core_schema__( # noqa: PLW3201 # Pydantic dunder + def __get_pydantic_core_schema__( # Pydantic dunder cls, source_type: Any, # noqa: ANN401 # Must allow `Any` to match Pydantic signature handler: GetCoreSchemaHandler, @@ -106,7 +108,7 @@ def __get_pydantic_core_schema__( # noqa: PLW3201 # Pydantic dunder ) @classmethod - def __get_pydantic_json_schema__( # noqa: PLW3201 # Pydantic dunder method + def __get_pydantic_json_schema__( # Pydantic dunder method cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: """Return a modified JSON schema for the secret string. 
diff --git a/airbyte_cdk/sql/shared/catalog_providers.py b/airbyte_cdk/sql/shared/catalog_providers.py index 80713a35..64bac21a 100644 --- a/airbyte_cdk/sql/shared/catalog_providers.py +++ b/airbyte_cdk/sql/shared/catalog_providers.py @@ -14,6 +14,7 @@ from airbyte_cdk.sql import exceptions as exc from airbyte_cdk.sql._util.name_normalizers import LowerCaseNormalizer + if TYPE_CHECKING: from airbyte_cdk.models import ConfiguredAirbyteStream diff --git a/airbyte_cdk/sql/shared/sql_processor.py b/airbyte_cdk/sql/shared/sql_processor.py index dd8cb3e5..d7a39863 100644 --- a/airbyte_cdk/sql/shared/sql_processor.py +++ b/airbyte_cdk/sql/shared/sql_processor.py @@ -13,6 +13,13 @@ import pandas as pd import sqlalchemy import ulid +from pandas import Index +from pydantic import BaseModel, Field +from sqlalchemy import Column, Table, and_, create_engine, insert, null, select, text, update +from sqlalchemy.exc import ProgrammingError, SQLAlchemyError + +from airbyte_protocol_dataclasses.models import AirbyteStateMessage + from airbyte_cdk.sql import exceptions as exc from airbyte_cdk.sql._util.hashing import one_way_hash from airbyte_cdk.sql._util.name_normalizers import LowerCaseNormalizer @@ -24,16 +31,11 @@ ) from airbyte_cdk.sql.secrets import SecretString from airbyte_cdk.sql.types import SQLTypeConverter -from airbyte_protocol_dataclasses.models import AirbyteStateMessage -from pandas import Index -from pydantic import BaseModel, Field -from sqlalchemy import Column, Table, and_, create_engine, insert, null, select, text, update -from sqlalchemy.exc import ProgrammingError, SQLAlchemyError + if TYPE_CHECKING: from collections.abc import Generator - from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.reflection import Inspector @@ -41,6 +43,8 @@ from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.type_api import TypeEngine + from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider + class SQLRuntimeError(Exception): """Raised when an SQL operation fails.""" diff --git a/airbyte_cdk/sql/types.py b/airbyte_cdk/sql/types.py index bb6fa1cb..8893670a 100644 --- a/airbyte_cdk/sql/types.py +++ b/airbyte_cdk/sql/types.py @@ -1,4 +1,4 @@ -# noqa: A005 # Allow shadowing the built-in 'types' module +# Allow shadowing the built-in 'types' module # Copyright (c) 2023 Airbyte, Inc., all rights reserved. """Type conversion methods for SQL Caches.""" @@ -9,6 +9,7 @@ import sqlalchemy + # Compare to documentation here: https://docs.airbyte.com/understanding-airbyte/supported-data-types CONVERSION_MAP = { "string": sqlalchemy.types.VARCHAR, diff --git a/airbyte_cdk/test/catalog_builder.py b/airbyte_cdk/test/catalog_builder.py index b1bf4341..23162e1a 100644 --- a/airbyte_cdk/test/catalog_builder.py +++ b/airbyte_cdk/test/catalog_builder.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations -from typing import Any, Dict, List, Union, overload +from typing import Any, overload from airbyte_cdk.models import ( ConfiguredAirbyteCatalog, @@ -12,7 +13,7 @@ class ConfiguredAirbyteStreamBuilder: def __init__(self) -> None: - self._stream: Dict[str, Any] = { + self._stream: dict[str, Any] = { "stream": { "name": "any name", "json_schema": {}, @@ -24,20 +25,20 @@ def __init__(self) -> None: "destination_sync_mode": "overwrite", } - def with_name(self, name: str) -> "ConfiguredAirbyteStreamBuilder": + def with_name(self, name: str) -> ConfiguredAirbyteStreamBuilder: self._stream["stream"]["name"] = name # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any] return self - def with_sync_mode(self, sync_mode: SyncMode) -> "ConfiguredAirbyteStreamBuilder": + def with_sync_mode(self, sync_mode: SyncMode) -> ConfiguredAirbyteStreamBuilder: self._stream["sync_mode"] = sync_mode.name return self - def with_primary_key(self, pk: List[List[str]]) -> "ConfiguredAirbyteStreamBuilder": + def with_primary_key(self, pk: list[list[str]]) -> ConfiguredAirbyteStreamBuilder: self._stream["primary_key"] = pk self._stream["stream"]["source_defined_primary_key"] = pk # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any] return self - def with_json_schema(self, json_schema: Dict[str, Any]) -> "ConfiguredAirbyteStreamBuilder": + def with_json_schema(self, json_schema: dict[str, Any]) -> ConfiguredAirbyteStreamBuilder: self._stream["stream"]["json_schema"] = json_schema return self @@ -47,19 +48,19 @@ def build(self) -> ConfiguredAirbyteStream: class CatalogBuilder: def __init__(self) -> None: - self._streams: List[ConfiguredAirbyteStreamBuilder] = [] + self._streams: list[ConfiguredAirbyteStreamBuilder] = [] @overload - def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder": ... + def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> CatalogBuilder: ... @overload - def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder": ... + def with_stream(self, name: str, sync_mode: SyncMode) -> CatalogBuilder: ... def with_stream( self, - name: Union[str, ConfiguredAirbyteStreamBuilder], - sync_mode: Union[SyncMode, None] = None, - ) -> "CatalogBuilder": + name: str | ConfiguredAirbyteStreamBuilder, + sync_mode: SyncMode | None = None, + ) -> CatalogBuilder: # As we are introducing a fully fledge ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface # with_stream(str, SyncMode) diff --git a/airbyte_cdk/test/entrypoint_wrapper.py b/airbyte_cdk/test/entrypoint_wrapper.py index 5e7a80da..37d4dbae 100644 --- a/airbyte_cdk/test/entrypoint_wrapper.py +++ b/airbyte_cdk/test/entrypoint_wrapper.py @@ -1,7 +1,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. -""" -The AirbyteEntrypoint is important because it is a service layer that orchestrate how we execute commands from the +"""The AirbyteEntrypoint is important because it is a service layer that orchestrate how we execute commands from the [common interface](https://docs.airbyte.com/understanding-airbyte/airbyte-protocol#common-interface) through the source Python implementation. There is some logic about which message we send to the platform and when which is relevant for integration testing. 
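For reference, a hedged sketch of how these test builders are typically chained; `build()` returning a `ConfiguredAirbyteCatalog` is assumed from context rather than shown in this hunk.

```python
# Sketch: assembling a configured catalog for a test.
from airbyte_cdk.models import SyncMode
from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder

catalog = (
    CatalogBuilder()
    .with_stream(
        ConfiguredAirbyteStreamBuilder()
        .with_name("users")
        .with_sync_mode(SyncMode.incremental)
        .with_primary_key([["id"]])
    )
    .build()
)
```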
Other than that, there are integrations point that are annoying to integrate with using Python code: @@ -14,14 +13,21 @@ * The entrypoint interface relies on file being written on the file system """ +from __future__ import annotations + import json import logging import re import tempfile import traceback +from collections.abc import Mapping from io import StringIO from pathlib import Path -from typing import Any, List, Mapping, Optional, Union +from typing import Any + +from orjson import orjson +from pydantic import ValidationError as V2ValidationError +from serpyco_rs import SchemaValidationError from airbyte_cdk.entrypoint import AirbyteEntrypoint from airbyte_cdk.exception_handler import assemble_uncaught_exception @@ -40,13 +46,10 @@ Type, ) from airbyte_cdk.sources import Source -from orjson import orjson -from pydantic import ValidationError as V2ValidationError -from serpyco_rs import SchemaValidationError class EntrypointOutput: - def __init__(self, messages: List[str], uncaught_exception: Optional[BaseException] = None): + def __init__(self, messages: list[str], uncaught_exception: BaseException | None = None): try: self._messages = [self._parse_message(message) for message in messages] except V2ValidationError as exception: @@ -70,15 +73,15 @@ def _parse_message(message: str) -> AirbyteMessage: ) @property - def records_and_state_messages(self) -> List[AirbyteMessage]: + def records_and_state_messages(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.RECORD, Type.STATE]) @property - def records(self) -> List[AirbyteMessage]: + def records(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.RECORD]) @property - def state_messages(self) -> List[AirbyteMessage]: + def state_messages(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.STATE]) @property @@ -89,19 +92,19 @@ def most_recent_state(self) -> Any: return state_messages[-1].state.stream # type: ignore[union-attr] # state has `stream` @property - def logs(self) -> List[AirbyteMessage]: + def logs(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.LOG]) @property - def trace_messages(self) -> List[AirbyteMessage]: + def trace_messages(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.TRACE]) @property - def analytics_messages(self) -> List[AirbyteMessage]: + def analytics_messages(self) -> list[AirbyteMessage]: return self._get_trace_message_by_trace_type(TraceType.ANALYTICS) @property - def errors(self) -> List[AirbyteMessage]: + def errors(self) -> list[AirbyteMessage]: return self._get_trace_message_by_trace_type(TraceType.ERROR) @property @@ -111,7 +114,7 @@ def catalog(self) -> AirbyteMessage: raise ValueError(f"Expected exactly one catalog but got {len(catalog)}") return catalog[0] - def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]: + def get_stream_statuses(self, stream_name: str) -> list[AirbyteStreamStatus]: status_messages = map( lambda message: message.trace.stream_status.status, # type: ignore filter( @@ -121,10 +124,10 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]: ) return list(status_messages) - def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessage]: + def _get_message_by_types(self, message_types: list[Type]) -> list[AirbyteMessage]: return [message for message in self._messages if message.type in message_types] - def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]: + def 
_get_trace_message_by_trace_type(self, trace_type: TraceType) -> list[AirbyteMessage]: return [ message for message in self._get_message_by_types([Type.TRACE]) @@ -143,7 +146,7 @@ def is_not_in_logs(self, pattern: str) -> bool: def _run_command( - source: Source, args: List[str], expecting_exception: bool = False + source: Source, args: list[str], expecting_exception: bool = False ) -> EntrypointOutput: log_capture_buffer = StringIO() stream_handler = logging.StreamHandler(log_capture_buffer) @@ -178,12 +181,10 @@ def discover( config: Mapping[str, Any], expecting_exception: bool = False, ) -> EntrypointOutput: - """ - config must be json serializable + """Config must be json serializable :param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please provide expecting_exception=True so that the test output logs are cleaner """ - with tempfile.TemporaryDirectory() as tmp_directory: tmp_directory_path = Path(tmp_directory) config_file = make_file(tmp_directory_path / "config.json", config) @@ -197,11 +198,10 @@ def read( source: Source, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, expecting_exception: bool = False, ) -> EntrypointOutput: - """ - config and state must be json serializable + """Config and state must be json serializable :param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please provide expecting_exception=True so that the test output logs are cleaner @@ -235,7 +235,7 @@ def read( def make_file( - path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]] + path: Path, file_contents: str | Mapping[str, Any] | list[Mapping[str, Any]] | None ) -> str: if isinstance(file_contents, str): path.write_text(file_contents) diff --git a/airbyte_cdk/test/mock_http/matcher.py b/airbyte_cdk/test/mock_http/matcher.py index d07cec3e..4221e00e 100644 --- a/airbyte_cdk/test/mock_http/matcher.py +++ b/airbyte_cdk/test/mock_http/matcher.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. -from typing import Any +from __future__ import annotations from airbyte_cdk.test.mock_http.request import HttpRequest @@ -35,7 +35,7 @@ def __str__(self) -> str: f"actual_number_of_matches={self._actual_number_of_matches})" ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, HttpRequestMatcher): return self._request_to_match == other._request_to_match return False diff --git a/airbyte_cdk/test/mock_http/mocker.py b/airbyte_cdk/test/mock_http/mocker.py index a62c46a5..106f5f9f 100644 --- a/airbyte_cdk/test/mock_http/mocker.py +++ b/airbyte_cdk/test/mock_http/mocker.py @@ -1,12 +1,14 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations import contextlib import functools +from collections.abc import Callable from enum import Enum from types import TracebackType -from typing import Callable, List, Optional, Union import requests_mock + from airbyte_cdk.test.mock_http import HttpRequest, HttpRequestMatcher, HttpResponse @@ -18,8 +20,7 @@ class SupportedHttpMethods(str, Enum): class HttpMocker(contextlib.ContextDecorator): - """ - WARNING 1: This implementation only works if the lib used to perform HTTP requests is `requests`. 
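A minimal sketch of driving a source through the `read` helper above and inspecting the resulting `EntrypointOutput`; `MySource` and the config/catalog values are placeholders, not part of this patch.

```python
# Sketch: exercising a source end-to-end in a test (MySource is a hypothetical Source implementation).
from airbyte_cdk.models import SyncMode
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_cdk.test.entrypoint_wrapper import read

source = MySource()  # hypothetical
config = {"api_key": "test-key"}
catalog = CatalogBuilder().with_stream("users", SyncMode.full_refresh).build()

output = read(source, config, catalog)
assert not output.errors
record_data = [message.record.data for message in output.records]
```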
+ """WARNING 1: This implementation only works if the lib used to perform HTTP requests is `requests`. WARNING 2: Given multiple requests that are not mutually exclusive, the request will match the first one. This can happen in scenarios where the same request is added twice (in which case there will always be an exception because we will never match the second @@ -36,17 +37,17 @@ class HttpMocker(contextlib.ContextDecorator): def __init__(self) -> None: self._mocker = requests_mock.Mocker() - self._matchers: List[HttpRequestMatcher] = [] + self._matchers: list[HttpRequestMatcher] = [] - def __enter__(self) -> "HttpMocker": + def __enter__(self) -> HttpMocker: self._mocker.__enter__() return self def __exit__( self, - exc_type: Optional[BaseException], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: BaseException | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self._mocker.__exit__(exc_type, exc_val, exc_tb) @@ -59,7 +60,7 @@ def _mock_request_method( self, method: SupportedHttpMethods, request: HttpRequest, - responses: Union[HttpResponse, List[HttpResponse]], + responses: HttpResponse | list[HttpResponse], ) -> None: if isinstance(responses, HttpResponse): responses = [responses] @@ -82,22 +83,16 @@ def _mock_request_method( ], ) - def get(self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]]) -> None: + def get(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.GET, request, responses) - def patch( - self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] - ) -> None: + def patch(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.PATCH, request, responses) - def post( - self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] - ) -> None: + def post(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.POST, request, responses) - def delete( - self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] - ) -> None: + def delete(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.DELETE, request, responses) @staticmethod diff --git a/airbyte_cdk/test/mock_http/request.py b/airbyte_cdk/test/mock_http/request.py index 7209513d..e4950a92 100644 --- a/airbyte_cdk/test/mock_http/request.py +++ b/airbyte_cdk/test/mock_http/request.py @@ -1,9 +1,12 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations import json -from typing import Any, List, Mapping, Optional, Union +from collections.abc import Mapping +from typing import Any from urllib.parse import parse_qs, urlencode, urlparse + ANY_QUERY_PARAMS = "any query_parameters" @@ -15,9 +18,9 @@ class HttpRequest: def __init__( self, url: str, - query_params: Optional[Union[str, Mapping[str, Union[str, List[str]]]]] = None, - headers: Optional[Mapping[str, str]] = None, - body: Optional[Union[str, bytes, Mapping[str, Any]]] = None, + query_params: str | Mapping[str, str | list[str]] | None = None, + headers: Mapping[str, str] | None = None, + body: str | bytes | Mapping[str, Any] | None = None, ) -> None: self._parsed_url = urlparse(url) self._query_params = query_params @@ -32,14 +35,13 @@ def __init__( self._body = body @staticmethod - def _encode_qs(query_params: Union[str, Mapping[str, Union[str, List[str]]]]) -> str: + def _encode_qs(query_params: str | Mapping[str, str | list[str]]) -> str: if isinstance(query_params, str): return query_params return urlencode(query_params, doseq=True) def matches(self, other: Any) -> bool: - """ - If the body of any request is a Mapping, we compare as Mappings which means that the order is not important. + """If the body of any request is a Mapping, we compare as Mappings which means that the order is not important. If the body is a string, encoding ISO-8859-1 will be assumed Headers only need to be a subset of `other` in order to match """ @@ -65,21 +67,21 @@ def matches(self, other: Any) -> bool: @staticmethod def _to_mapping( - body: Optional[Union[str, bytes, Mapping[str, Any]]], - ) -> Optional[Mapping[str, Any]]: + body: str | bytes | Mapping[str, Any] | None, + ) -> Mapping[str, Any] | None: if isinstance(body, Mapping): return body - elif isinstance(body, bytes): + if isinstance(body, bytes): return json.loads(body.decode()) # type: ignore # assumes return type of Mapping[str, Any] - elif isinstance(body, str): + if isinstance(body, str): return json.loads(body) # type: ignore # assumes return type of Mapping[str, Any] return None @staticmethod - def _to_bytes(body: Optional[Union[str, bytes]]) -> bytes: + def _to_bytes(body: str | bytes | None) -> bytes: if isinstance(body, bytes): return body - elif isinstance(body, str): + if isinstance(body, str): # `ISO-8859-1` is the default encoding used by requests return body.encode("ISO-8859-1") return b"" @@ -92,7 +94,7 @@ def __repr__(self) -> str: f"HttpRequest(request={self._parsed_url}, headers={self._headers}, body={self._body!r})" ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, HttpRequest): return ( self._parsed_url == other._parsed_url diff --git a/airbyte_cdk/test/mock_http/response.py b/airbyte_cdk/test/mock_http/response.py index 848be55a..0ce900ab 100644 --- a/airbyte_cdk/test/mock_http/response.py +++ b/airbyte_cdk/test/mock_http/response.py @@ -1,7 +1,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations +from collections.abc import Mapping from types import MappingProxyType -from typing import Mapping class HttpResponse: diff --git a/airbyte_cdk/test/mock_http/response_builder.py b/airbyte_cdk/test/mock_http/response_builder.py index b517343e..8089adf4 100644 --- a/airbyte_cdk/test/mock_http/response_builder.py +++ b/airbyte_cdk/test/mock_http/response_builder.py @@ -1,27 +1,28 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
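To illustrate the matching semantics described in the docstring above (Mapping bodies compare order-insensitively, headers only need to be a subset of the other request), a small sketch with illustrative values:

```python
# Sketch: HttpRequest matching per the docstring above.
from airbyte_cdk.test.mock_http.request import HttpRequest

expected = HttpRequest(
    url="https://api.example.com/users",
    headers={"Authorization": "Bearer token"},
    body={"page": 1, "page_size": 50},
)
actual = HttpRequest(
    url="https://api.example.com/users",
    headers={"Authorization": "Bearer token", "User-Agent": "python-requests"},
    body='{"page_size": 50, "page": 1}',
)
assert expected.matches(actual)
```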
+from __future__ import annotations import functools import json from abc import ABC, abstractmethod from pathlib import Path as FilePath -from typing import Any, Dict, List, Optional, Union +from typing import Any from airbyte_cdk.test.mock_http import HttpResponse from airbyte_cdk.test.utils.data import get_unit_test_folder -def _extract(path: List[str], response_template: Dict[str, Any]) -> Any: +def _extract(path: list[str], response_template: dict[str, Any]) -> Any: return functools.reduce(lambda a, b: a[b], path, response_template) -def _replace_value(dictionary: Dict[str, Any], path: List[str], value: Any) -> None: +def _replace_value(dictionary: dict[str, Any], path: list[str], value: Any) -> None: current = dictionary for key in path[:-1]: current = current[key] current[path[-1]] = value -def _write(dictionary: Dict[str, Any], path: List[str], value: Any) -> None: +def _write(dictionary: dict[str, Any], path: list[str], value: Any) -> None: current = dictionary for key in path[:-1]: current = current.setdefault(key, {}) @@ -30,14 +31,14 @@ def _write(dictionary: Dict[str, Any], path: List[str], value: Any) -> None: class Path(ABC): @abstractmethod - def write(self, template: Dict[str, Any], value: Any) -> None: + def write(self, template: dict[str, Any], value: Any) -> None: pass @abstractmethod - def update(self, template: Dict[str, Any], value: Any) -> None: + def update(self, template: dict[str, Any], value: Any) -> None: pass - def extract(self, template: Dict[str, Any]) -> Any: + def extract(self, template: dict[str, Any]) -> Any: pass @@ -45,13 +46,13 @@ class FieldPath(Path): def __init__(self, field: str): self._path = [field] - def write(self, template: Dict[str, Any], value: Any) -> None: + def write(self, template: dict[str, Any], value: Any) -> None: _write(template, self._path, value) - def update(self, template: Dict[str, Any], value: Any) -> None: + def update(self, template: dict[str, Any], value: Any) -> None: _replace_value(template, self._path, value) - def extract(self, template: Dict[str, Any]) -> Any: + def extract(self, template: dict[str, Any]) -> Any: return _extract(self._path, template) def __str__(self) -> str: @@ -59,16 +60,16 @@ def __str__(self) -> str: class NestedPath(Path): - def __init__(self, path: List[str]): + def __init__(self, path: list[str]): self._path = path - def write(self, template: Dict[str, Any], value: Any) -> None: + def write(self, template: dict[str, Any], value: Any) -> None: _write(template, self._path, value) - def update(self, template: Dict[str, Any], value: Any) -> None: + def update(self, template: dict[str, Any], value: Any) -> None: _replace_value(template, self._path, value) - def extract(self, template: Dict[str, Any]) -> Any: + def extract(self, template: dict[str, Any]) -> Any: return _extract(self._path, template) def __str__(self) -> str: @@ -77,7 +78,7 @@ def __str__(self) -> str: class PaginationStrategy(ABC): @abstractmethod - def update(self, response: Dict[str, Any]) -> None: + def update(self, response: dict[str, Any]) -> None: pass @@ -86,16 +87,16 @@ def __init__(self, path: Path, value: Any): self._path = path self._value = value - def update(self, response: Dict[str, Any]) -> None: + def update(self, response: dict[str, Any]) -> None: self._path.update(response, self._value) class RecordBuilder: def __init__( self, - template: Dict[str, Any], - id_path: Optional[Path], - cursor_path: Optional[Union[FieldPath, NestedPath]], + template: dict[str, Any], + id_path: Path | None, + cursor_path: FieldPath | 
NestedPath | None, ): self._record = template self._id_path = id_path @@ -111,7 +112,7 @@ def _validate_template(self) -> None: for field_name, field_path in paths_to_validate: self._validate_field(field_name, field_path) - def _validate_field(self, field_name: str, path: Optional[Path]) -> None: + def _validate_field(self, field_name: str, path: Path | None) -> None: try: if path and not path.extract(self._record): raise ValueError( @@ -122,19 +123,19 @@ def _validate_field(self, field_name: str, path: Optional[Path]) -> None: f"{field_name} `{path}` was provided but it is not part of the template `{self._record}`" ) from exception - def with_id(self, identifier: Any) -> "RecordBuilder": + def with_id(self, identifier: Any) -> RecordBuilder: self._set_field("id", self._id_path, identifier) return self - def with_cursor(self, cursor_value: Any) -> "RecordBuilder": + def with_cursor(self, cursor_value: Any) -> RecordBuilder: self._set_field("cursor", self._cursor_path, cursor_value) return self - def with_field(self, path: Path, value: Any) -> "RecordBuilder": + def with_field(self, path: Path, value: Any) -> RecordBuilder: path.write(self._record, value) return self - def _set_field(self, field_name: str, path: Optional[Path], value: Any) -> None: + def _set_field(self, field_name: str, path: Path | None, value: Any) -> None: if not path: raise ValueError( f"{field_name}_path was not provided and hence, the record {field_name} can't be modified. Please provide `id_field` while " @@ -142,28 +143,28 @@ def _set_field(self, field_name: str, path: Optional[Path], value: Any) -> None: ) path.update(self._record, value) - def build(self) -> Dict[str, Any]: + def build(self) -> dict[str, Any]: return self._record class HttpResponseBuilder: def __init__( self, - template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - pagination_strategy: Optional[PaginationStrategy], + template: dict[str, Any], + records_path: FieldPath | NestedPath, + pagination_strategy: PaginationStrategy | None, ): self._response = template - self._records: List[RecordBuilder] = [] + self._records: list[RecordBuilder] = [] self._records_path = records_path self._pagination_strategy = pagination_strategy self._status_code = 200 - def with_record(self, record: RecordBuilder) -> "HttpResponseBuilder": + def with_record(self, record: RecordBuilder) -> HttpResponseBuilder: self._records.append(record) return self - def with_pagination(self) -> "HttpResponseBuilder": + def with_pagination(self) -> HttpResponseBuilder: if not self._pagination_strategy: raise ValueError( "`pagination_strategy` was not provided and hence, fields related to the pagination can't be modified. 
Please provide " @@ -172,7 +173,7 @@ def with_pagination(self) -> "HttpResponseBuilder": self._pagination_strategy.update(self._response) return self - def with_status_code(self, status_code: int) -> "HttpResponseBuilder": + def with_status_code(self, status_code: int) -> HttpResponseBuilder: self._status_code = status_code return self @@ -186,7 +187,7 @@ def _get_unit_test_folder(execution_folder: str) -> FilePath: return get_unit_test_folder(execution_folder) # type: ignore # get_unit_test_folder is known to return a FilePath -def find_template(resource: str, execution_folder: str) -> Dict[str, Any]: +def find_template(resource: str, execution_folder: str) -> dict[str, Any]: response_template_filepath = str( get_unit_test_folder(execution_folder) / "resource" @@ -194,19 +195,17 @@ def find_template(resource: str, execution_folder: str) -> Dict[str, Any]: / "response" / f"{resource}.json" ) - with open(response_template_filepath, "r") as template_file: + with open(response_template_filepath) as template_file: return json.load(template_file) # type: ignore # we assume the dev correctly set up the resource file def create_record_builder( - response_template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - record_id_path: Optional[Path] = None, - record_cursor_path: Optional[Union[FieldPath, NestedPath]] = None, + response_template: dict[str, Any], + records_path: FieldPath | NestedPath, + record_id_path: Path | None = None, + record_cursor_path: FieldPath | NestedPath | None = None, ) -> RecordBuilder: - """ - This will use the first record define at `records_path` as a template for the records. If more records are defined, they will be ignored - """ + """This will use the first record define at `records_path` as a template for the records. If more records are defined, they will be ignored""" try: record_template = records_path.extract(response_template)[0] if not record_template: @@ -222,8 +221,8 @@ def create_record_builder( def create_response_builder( - response_template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - pagination_strategy: Optional[PaginationStrategy] = None, + response_template: dict[str, Any], + records_path: FieldPath | NestedPath, + pagination_strategy: PaginationStrategy | None = None, ) -> HttpResponseBuilder: return HttpResponseBuilder(response_template, records_path, pagination_strategy) diff --git a/airbyte_cdk/test/state_builder.py b/airbyte_cdk/test/state_builder.py index a1315cf4..47f30cda 100644 --- a/airbyte_cdk/test/state_builder.py +++ b/airbyte_cdk/test/state_builder.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
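Putting the builders above together, a hedged sketch of composing a mocked response from a JSON template; the resource name, field paths, and `build()` returning an `HttpResponse` are assumptions for illustration.

```python
# Sketch: building a mocked response from a JSON template under the unit test's resource folder.
from airbyte_cdk.test.mock_http.response_builder import (
    FieldPath,
    create_record_builder,
    create_response_builder,
    find_template,
)

template = find_template("users", __file__)  # resolves the "users" JSON template for this test
record = create_record_builder(
    template,
    records_path=FieldPath("data"),
    record_id_path=FieldPath("id"),
    record_cursor_path=FieldPath("updated_at"),
)
response = (
    create_response_builder(template, records_path=FieldPath("data"))
    .with_record(record.with_id("user-1").with_cursor("2024-01-01T00:00:00Z"))
    .with_status_code(200)
    .build()
)
```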
+from __future__ import annotations -from typing import Any, List +from typing import Any from airbyte_cdk.models import ( AirbyteStateBlob, @@ -13,9 +14,9 @@ class StateBuilder: def __init__(self) -> None: - self._state: List[AirbyteStateMessage] = [] + self._state: list[AirbyteStateMessage] = [] - def with_stream_state(self, stream_name: str, state: Any) -> "StateBuilder": + def with_stream_state(self, stream_name: str, state: Any) -> StateBuilder: self._state.append( AirbyteStateMessage( type=AirbyteStateType.STREAM, @@ -23,11 +24,11 @@ def with_stream_state(self, stream_name: str, state: Any) -> "StateBuilder": stream_state=state if isinstance(state, AirbyteStateBlob) else AirbyteStateBlob(state), - stream_descriptor=StreamDescriptor(**{"name": stream_name}), + stream_descriptor=StreamDescriptor(name=stream_name), ), ) ) return self - def build(self) -> List[AirbyteStateMessage]: + def build(self) -> list[AirbyteStateMessage]: return self._state diff --git a/airbyte_cdk/test/utils/data.py b/airbyte_cdk/test/utils/data.py index 6aaeb839..a8d96996 100644 --- a/airbyte_cdk/test/utils/data.py +++ b/airbyte_cdk/test/utils/data.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from pydantic import FilePath diff --git a/airbyte_cdk/test/utils/http_mocking.py b/airbyte_cdk/test/utils/http_mocking.py index 7fd1419f..d76c92d0 100644 --- a/airbyte_cdk/test/utils/http_mocking.py +++ b/airbyte_cdk/test/utils/http_mocking.py @@ -1,7 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import re -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from requests_mock import Mocker diff --git a/airbyte_cdk/test/utils/reading.py b/airbyte_cdk/test/utils/reading.py index 2d89cb87..3ba35c2c 100644 --- a/airbyte_cdk/test/utils/reading.py +++ b/airbyte_cdk/test/utils/reading.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any from airbyte_cdk import AbstractSource from airbyte_cdk.models import AirbyteStateMessage, ConfiguredAirbyteCatalog, SyncMode @@ -18,7 +20,7 @@ def read_records( config: Mapping[str, Any], stream_name: str, sync_mode: SyncMode, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, expecting_exception: bool = False, ) -> EntrypointOutput: """Read records from a stream.""" diff --git a/airbyte_cdk/utils/airbyte_secrets_utils.py b/airbyte_cdk/utils/airbyte_secrets_utils.py index 45279e57..ac9d6ba5 100644 --- a/airbyte_cdk/utils/airbyte_secrets_utils.py +++ b/airbyte_cdk/utils/airbyte_secrets_utils.py @@ -1,18 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
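A small sketch of building incoming per-stream state with the builder above; the stream name and cursor value are illustrative. The result can be passed as the `state` argument of the `read` helper shown earlier.

```python
# Sketch: constructing state messages for an incremental read test.
from airbyte_cdk.test.state_builder import StateBuilder

state = (
    StateBuilder()
    .with_stream_state("users", {"updated_at": "2024-01-01T00:00:00Z"})
    .build()
)
```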
# +from __future__ import annotations -from typing import Any, List, Mapping +from collections.abc import Mapping +from typing import Any import dpath -def get_secret_paths(spec: Mapping[str, Any]) -> List[List[str]]: +def get_secret_paths(spec: Mapping[str, Any]) -> list[list[str]]: paths = [] - def traverse_schema(schema_item: Any, path: List[str]) -> None: - """ - schema_item can be any property or value in the originally input jsonschema, depending on how far down the recursion stack we go + def traverse_schema(schema_item: Any, path: list[str]) -> None: + """schema_item can be any property or value in the originally input jsonschema, depending on how far down the recursion stack we go path is the path to that schema item in the original input for example if we have the input {'password': {'type': 'string', 'airbyte_secret': True}} then the arguments will evolve as follows: @@ -27,10 +28,9 @@ def traverse_schema(schema_item: Any, path: List[str]) -> None: elif isinstance(schema_item, list): for i in schema_item: traverse_schema(i, path) - else: - if path[-1] == "airbyte_secret" and schema_item is True: - filtered_path = [p for p in path[:-1] if p not in ["properties", "oneOf"]] - paths.append(filtered_path) + elif path[-1] == "airbyte_secret" and schema_item is True: + filtered_path = [p for p in path[:-1] if p not in ["properties", "oneOf"]] + paths.append(filtered_path) traverse_schema(spec, []) return paths @@ -38,9 +38,8 @@ def traverse_schema(schema_item: Any, path: List[str]) -> None: def get_secrets( connection_specification: Mapping[str, Any], config: Mapping[str, Any] -) -> List[Any]: - """ - Get a list of secret values from the source config based on the source specification +) -> list[Any]: + """Get a list of secret values from the source config based on the source specification :type connection_specification: the connection_specification field of an AirbyteSpecification i.e the JSONSchema definition """ secret_paths = get_secret_paths(connection_specification.get("properties", {})) @@ -55,10 +54,10 @@ def get_secrets( return result -__SECRETS_FROM_CONFIG: List[str] = [] +__SECRETS_FROM_CONFIG: list[str] = [] -def update_secrets(secrets: List[str]) -> None: +def update_secrets(secrets: list[str]) -> None: """Update the list of secrets to be replaced""" global __SECRETS_FROM_CONFIG __SECRETS_FROM_CONFIG = secrets diff --git a/airbyte_cdk/utils/analytics_message.py b/airbyte_cdk/utils/analytics_message.py index 82a07491..ad5a379c 100644 --- a/airbyte_cdk/utils/analytics_message.py +++ b/airbyte_cdk/utils/analytics_message.py @@ -1,7 +1,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations import time -from typing import Any, Optional +from typing import Any from airbyte_cdk.models import ( AirbyteAnalyticsTraceMessage, @@ -12,7 +13,7 @@ ) -def create_analytics_message(type: str, value: Optional[Any]) -> AirbyteMessage: +def create_analytics_message(type: str, value: Any | None) -> AirbyteMessage: return AirbyteMessage( type=Type.TRACE, trace=AirbyteTraceMessage( diff --git a/airbyte_cdk/utils/constants.py b/airbyte_cdk/utils/constants.py index 1d6345cb..d471864d 100644 --- a/airbyte_cdk/utils/constants.py +++ b/airbyte_cdk/utils/constants.py @@ -1,5 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
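To make the traversal above concrete, a sketch with a tiny spec; the expected outputs follow from the logic shown in this hunk, and the field names are illustrative.

```python
# Sketch: finding secret paths in a spec and pulling the secret values out of a config.
from airbyte_cdk.utils.airbyte_secrets_utils import get_secret_paths, get_secrets

spec = {
    "type": "object",
    "properties": {
        "api_key": {"type": "string", "airbyte_secret": True},
        "start_date": {"type": "string"},
    },
}
config = {"api_key": "very-secret", "start_date": "2024-01-01"}

assert get_secret_paths(spec["properties"]) == [["api_key"]]
assert get_secrets(spec, config) == ["very-secret"]
```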
# +from __future__ import annotations + ENV_REQUEST_CACHE_PATH = "REQUEST_CACHE_PATH" diff --git a/airbyte_cdk/utils/datetime_format_inferrer.py b/airbyte_cdk/utils/datetime_format_inferrer.py index 28eaefa3..b69424c6 100644 --- a/airbyte_cdk/utils/datetime_format_inferrer.py +++ b/airbyte_cdk/utils/datetime_format_inferrer.py @@ -1,21 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.models import AirbyteRecordMessage from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser class DatetimeFormatInferrer: - """ - This class is used to detect toplevel fields in records that might be datetime values, along with the used format. - """ + """This class is used to detect toplevel fields in records that might be datetime values, along with the used format.""" def __init__(self) -> None: self._parser = DatetimeParser() - self._datetime_candidates: Optional[Dict[str, str]] = None + self._datetime_candidates: dict[str, str] | None = None self._formats = [ "%Y-%m-%d", "%Y-%m-%d %H:%M:%S", @@ -38,7 +37,8 @@ def _can_be_datetime(self, value: Any) -> bool: """Checks if the value can be a datetime. This is the case if the value is a string or an integer between 1_000_000_000 and 2_000_000_000 for seconds or between 1_000_000_000_000 and 2_000_000_000_000 for milliseconds. - This is separate from the format check for performance reasons""" + This is separate from the format check for performance reasons + """ if isinstance(value, (str, int)): try: value_as_int = int(value) @@ -86,9 +86,8 @@ def accumulate(self, record: AirbyteRecordMessage) -> None: """Analyzes the record and updates the internal state of candidate datetime fields""" self._initialize(record) if self._datetime_candidates is None else self._validate(record) - def get_inferred_datetime_formats(self) -> Dict[str, str]: - """ - Returns the list of candidate datetime fields - the keys are the field names and the values are the inferred datetime formats. + def get_inferred_datetime_formats(self) -> dict[str, str]: + """Returns the list of candidate datetime fields - the keys are the field names and the values are the inferred datetime formats. For these fields the format was consistent across all visited records. """ return self._datetime_candidates or {} diff --git a/airbyte_cdk/utils/event_timing.py b/airbyte_cdk/utils/event_timing.py index 447543ec..56f6ba79 100644 --- a/airbyte_cdk/utils/event_timing.py +++ b/airbyte_cdk/utils/event_timing.py @@ -1,13 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import logging import time from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Optional + logger = logging.getLogger("airbyte") @@ -25,18 +26,13 @@ def __init__(self, name): self.stack = [] def start_event(self, name): - """ - Start a new event and push it to the stack. - """ + """Start a new event and push it to the stack.""" self.events[name] = Event(name=name) self.count += 1 self.stack.insert(0, self.events[name]) def finish_event(self): - """ - Finish the current event and pop it from the stack. 
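A short sketch of the inferrer in use; the field values are illustrative and the inferred format follows the `_formats` list above.

```python
# Sketch: accumulating records and reading back the inferred datetime formats.
from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer

inferrer = DatetimeFormatInferrer()
for data in ({"created_at": "2024-01-01"}, {"created_at": "2024-02-15"}):
    inferrer.accumulate(
        AirbyteRecordMessage(stream="users", data=data, emitted_at=1700000000000)
    )

print(inferrer.get_inferred_datetime_formats())  # expected: {"created_at": "%Y-%m-%d"}
```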
- """ - + """Finish the current event and pop it from the stack.""" if self.stack: event = self.stack.pop(0) event.finish() @@ -44,9 +40,7 @@ def finish_event(self): logger.warning(f"{self.name} finish_event called without start_event") def report(self, order_by="name"): - """ - :param order_by: 'name' or 'duration' - """ + """:param order_by: 'name' or 'duration'""" if order_by == "name": events = sorted(self.events.values(), key=lambda event: event.name) elif order_by == "duration": @@ -60,7 +54,7 @@ def report(self, order_by="name"): class Event: name: str start: float = field(default_factory=time.perf_counter_ns) - end: Optional[float] = field(default=None) + end: float | None = field(default=None) @property def duration(self) -> float: @@ -78,8 +72,6 @@ def finish(self): @contextmanager def create_timer(name): - """ - Creates a new EventTimer as a context manager to improve code readability. - """ + """Creates a new EventTimer as a context manager to improve code readability.""" a_timer = EventTimer(name) yield a_timer diff --git a/airbyte_cdk/utils/is_cloud_environment.py b/airbyte_cdk/utils/is_cloud_environment.py index 25b1eee8..139cc29a 100644 --- a/airbyte_cdk/utils/is_cloud_environment.py +++ b/airbyte_cdk/utils/is_cloud_environment.py @@ -1,15 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import os + CLOUD_DEPLOYMENT_MODE = "cloud" def is_cloud_environment() -> bool: - """ - Returns True if the connector is running in a cloud environment, False otherwise. + """Returns True if the connector is running in a cloud environment, False otherwise. The function checks the value of the DEPLOYMENT_MODE environment variable which is set by the platform. This function can be used to determine whether stricter security measures should be applied. diff --git a/airbyte_cdk/utils/mapping_helpers.py b/airbyte_cdk/utils/mapping_helpers.py index 469fb5e0..f534ec8b 100644 --- a/airbyte_cdk/utils/mapping_helpers.py +++ b/airbyte_cdk/utils/mapping_helpers.py @@ -1,22 +1,22 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations - -from typing import Any, List, Mapping, Optional, Set, Union +from collections.abc import Mapping +from typing import Any def combine_mappings( - mappings: List[Optional[Union[Mapping[str, Any], str]]], -) -> Union[Mapping[str, Any], str]: - """ - Combine multiple mappings into a single mapping. If any of the mappings are a string, return + mappings: list[Mapping[str, Any] | str | None], +) -> Mapping[str, Any] | str: + """Combine multiple mappings into a single mapping. If any of the mappings are a string, return that string. Raise errors in the following cases: * If there are duplicate keys across mappings * If there are multiple string mappings * If there are multiple mappings containing keys and one of them is a string """ - all_keys: List[Set[str]] = [] + all_keys: list[set[str]] = [] for part in mappings: if part is None: continue diff --git a/airbyte_cdk/utils/message_utils.py b/airbyte_cdk/utils/message_utils.py index 7e740b78..11b036b2 100644 --- a/airbyte_cdk/utils/message_utils.py +++ b/airbyte_cdk/utils/message_utils.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations from airbyte_cdk.models import AirbyteMessage, Type from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor diff --git a/airbyte_cdk/utils/oneof_option_config.py b/airbyte_cdk/utils/oneof_option_config.py index 17ebf051..aaacd11e 100644 --- a/airbyte_cdk/utils/oneof_option_config.py +++ b/airbyte_cdk/utils/oneof_option_config.py @@ -1,13 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, Dict +from typing import Any class OneOfOptionConfig: - """ - Base class to configure a Pydantic model that's used as a oneOf option in a parent model in a way that's compatible with all Airbyte consumers. + """Base class to configure a Pydantic model that's used as a oneOf option in a parent model in a way that's compatible with all Airbyte consumers. Inherit from this class in the nested Config class in a model and set title and description (these show up in the UI) and discriminator (this is making sure it's marked as required in the schema). @@ -26,7 +26,7 @@ class Config(OneOfOptionConfig): """ @staticmethod - def schema_extra(schema: Dict[str, Any], model: Any) -> None: + def schema_extra(schema: dict[str, Any], model: Any) -> None: if hasattr(model.Config, "description"): schema["description"] = model.Config.description if hasattr(model.Config, "discriminator"): diff --git a/airbyte_cdk/utils/print_buffer.py b/airbyte_cdk/utils/print_buffer.py index ae5a2020..f90611ca 100644 --- a/airbyte_cdk/utils/print_buffer.py +++ b/airbyte_cdk/utils/print_buffer.py @@ -1,16 +1,15 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import sys import time from io import StringIO from threading import RLock from types import TracebackType -from typing import Optional class PrintBuffer: - """ - A class to buffer print statements and flush them at a specified interval. + """A class to buffer print statements and flush them at a specified interval. The PrintBuffer class is designed to capture and buffer output that would normally be printed to the standard output (stdout). This can be useful for @@ -57,7 +56,7 @@ def flush(self) -> None: sys.__stdout__.write(combined_message) # type: ignore[union-attr] self.buffer = StringIO() - def __enter__(self) -> "PrintBuffer": + def __enter__(self) -> PrintBuffer: self.old_stdout, self.old_stderr = sys.stdout, sys.stderr # Used to disable buffering during the pytest session, because it is not compatible with capsys if "pytest" not in str(type(sys.stdout)).lower(): @@ -67,9 +66,9 @@ def __enter__(self) -> "PrintBuffer": def __exit__( self, - exc_type: Optional[BaseException], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: BaseException | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self.flush() sys.stdout, sys.stderr = self.old_stdout, self.old_stderr diff --git a/airbyte_cdk/utils/schema_inferrer.py b/airbyte_cdk/utils/schema_inferrer.py index 65d44369..7f5c0ebd 100644 --- a/airbyte_cdk/utils/schema_inferrer.py +++ b/airbyte_cdk/utils/schema_inferrer.py @@ -1,15 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
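A minimal sketch of the buffer as a context manager; the flush interval default and the pytest opt-out noted in `__enter__` come from the surrounding code, everything else is illustrative.

```python
# Sketch: buffering stdout writes and flushing them on exit.
from airbyte_cdk.utils.print_buffer import PrintBuffer

with PrintBuffer():
    for i in range(3):
        print(f"message {i}")  # captured by the buffer instead of hitting stdout immediately
# On exit, remaining buffered output is flushed and stdout/stderr are restored.
```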
# +from __future__ import annotations from collections import defaultdict -from typing import Any, Dict, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any -from airbyte_cdk.models import AirbyteRecordMessage from genson import SchemaBuilder, SchemaNode from genson.schema.strategies.object import Object from genson.schema.strategies.scalar import Number +from airbyte_cdk.models import AirbyteRecordMessage + + # schema keywords _TYPE = "type" _NULL_TYPE = "null" @@ -21,8 +25,7 @@ class NoRequiredObj(Object): - """ - This class has Object behaviour, but it does not generate "required[]" fields + """This class has Object behaviour, but it does not generate "required[]" fields every time it parses object. So we don't add unnecessary extra field. The logic is that even reading all the data from a source, it does not mean that there can be another record added with those fields as @@ -30,15 +33,13 @@ class NoRequiredObj(Object): """ def to_schema(self) -> Mapping[str, Any]: - schema: Dict[str, Any] = super(NoRequiredObj, self).to_schema() + schema: dict[str, Any] = super(NoRequiredObj, self).to_schema() schema.pop("required", None) return schema class IntegerToNumber(Number): - """ - This class has the regular Number behaviour, but it will never emit an integer type. - """ + """This class has the regular Number behaviour, but it will never emit an integer type.""" def __init__(self, node_class: SchemaNode): super().__init__(node_class) @@ -50,21 +51,21 @@ class NoRequiredSchemaBuilder(SchemaBuilder): # This type is inferred from the genson lib, but there is no alias provided for it - creating it here for type safety -InferredSchema = Dict[str, Any] +InferredSchema = dict[str, Any] class SchemaValidationException(Exception): @classmethod def merge_exceptions( - cls, exceptions: List["SchemaValidationException"] - ) -> "SchemaValidationException": + cls, exceptions: list[SchemaValidationException] + ) -> SchemaValidationException: # We assume the schema is the same for all SchemaValidationException return SchemaValidationException( exceptions[0].schema, [x for exception in exceptions for x in exception._validation_errors], ) - def __init__(self, schema: InferredSchema, validation_errors: List[Exception]): + def __init__(self, schema: InferredSchema, validation_errors: list[Exception]): self._schema = schema self._validation_errors = validation_errors @@ -73,13 +74,12 @@ def schema(self) -> InferredSchema: return self._schema @property - def validation_errors(self) -> List[str]: + def validation_errors(self) -> list[str]: return list(map(lambda error: str(error), self._validation_errors)) class SchemaInferrer: - """ - This class is used to infer a JSON schema which fits all the records passed into it + """This class is used to infer a JSON schema which fits all the records passed into it throughout its lifecycle via the accumulate method. 
Instances of this class are stateful, meaning they build their inferred schemas @@ -87,10 +87,10 @@ class SchemaInferrer: """ - stream_to_builder: Dict[str, SchemaBuilder] + stream_to_builder: dict[str, SchemaBuilder] def __init__( - self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None + self, pk: list[list[str]] | None = None, cursor_field: list[list[str]] | None = None ) -> None: self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder) self._pk = [] if pk is None else pk @@ -103,8 +103,7 @@ def accumulate(self, record: AirbyteRecordMessage) -> None: def _null_type_in_any_of(self, node: InferredSchema) -> bool: if _ANY_OF in node: return {_TYPE: _NULL_TYPE} in node[_ANY_OF] - else: - return False + return False def _remove_type_from_any_of(self, node: InferredSchema) -> None: if _ANY_OF in node: @@ -139,12 +138,10 @@ def _ensure_null_type_on_top(self, node: InferredSchema) -> None: node[_TYPE] = [node[_TYPE], _NULL_TYPE] def _clean(self, node: InferredSchema) -> InferredSchema: - """ - Recursively cleans up a produced schema: + """Recursively cleans up a produced schema: - remove anyOf if one of them is just a null value - remove properties of type "null" """ - if isinstance(node, dict): if _ANY_OF in node: self._clean_any_of(node) @@ -164,8 +161,7 @@ def _clean(self, node: InferredSchema) -> InferredSchema: return node def _add_required_properties(self, node: InferredSchema) -> InferredSchema: - """ - This method takes properties that should be marked as required (self._pk and self._cursor_field) and travel the schema to mark every + """This method takes properties that should be marked as required (self._pk and self._cursor_field) and travel the schema to mark every node as required. """ # Removing nullable for the root as when we call `_clean`, we make everything nullable @@ -183,11 +179,9 @@ def _add_required_properties(self, node: InferredSchema) -> InferredSchema: return node - def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List[str]]) -> None: - """ - Take a list of nested keys (this list represents a composite key) and travel the schema to mark every node as required. - """ - errors: List[Exception] = [] + def _add_fields_as_required(self, node: InferredSchema, composite_key: list[list[str]]) -> None: + """Take a list of nested keys (this list represents a composite key) and travel the schema to mark every node as required.""" + errors: list[Exception] = [] for path in composite_key: try: @@ -199,11 +193,9 @@ def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List raise SchemaValidationException(node, errors) def _add_field_as_required( - self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None + self, node: InferredSchema, path: list[str], traveled_path: list[str] | None = None ) -> None: - """ - Take a nested key and travel the schema to mark every node as required. 
- """ + """Take a nested key and travel the schema to mark every node as required.""" self._remove_null_from_type(node) if self._is_leaf(path): return @@ -246,7 +238,7 @@ def _add_field_as_required( traveled_path.append(next_node) self._add_field_as_required(node[_PROPERTIES][next_node], path[1:], traveled_path) - def _is_leaf(self, path: List[str]) -> bool: + def _is_leaf(self, path: list[str]) -> bool: return len(path) == 0 def _remove_null_from_type(self, node: InferredSchema) -> None: @@ -256,10 +248,8 @@ def _remove_null_from_type(self, node: InferredSchema) -> None: if len(node[_TYPE]) == 1: node[_TYPE] = node[_TYPE][0] - def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]: - """ - Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name. - """ + def get_stream_schema(self, stream_name: str) -> InferredSchema | None: + """Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name.""" return ( self._add_required_properties( self._clean(self.stream_to_builder[stream_name].to_schema()) diff --git a/airbyte_cdk/utils/spec_schema_transformations.py b/airbyte_cdk/utils/spec_schema_transformations.py index 8d47f83e..f54a24b5 100644 --- a/airbyte_cdk/utils/spec_schema_transformations.py +++ b/airbyte_cdk/utils/spec_schema_transformations.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import re @@ -9,8 +10,7 @@ def resolve_refs(schema: dict) -> dict: - """ - For spec schemas generated using Pydantic models, the resulting JSON schema can contain refs between object + """For spec schemas generated using Pydantic models, the resulting JSON schema can contain refs between object relationships. """ json_schema_ref_resolver = RefResolver.from_schema(schema) diff --git a/airbyte_cdk/utils/stream_status_utils.py b/airbyte_cdk/utils/stream_status_utils.py index 49c07f49..89a7a31b 100644 --- a/airbyte_cdk/utils/stream_status_utils.py +++ b/airbyte_cdk/utils/stream_status_utils.py @@ -1,10 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations from datetime import datetime -from typing import List, Optional, Union from airbyte_cdk.models import ( AirbyteMessage, @@ -20,14 +19,11 @@ def as_airbyte_message( - stream: Union[AirbyteStream, StreamDescriptor], + stream: AirbyteStream | StreamDescriptor, current_status: AirbyteStreamStatus, - reasons: Optional[List[AirbyteStreamStatusReason]] = None, + reasons: list[AirbyteStreamStatusReason] | None = None, ) -> AirbyteMessage: - """ - Builds an AirbyteStreamStatusTraceMessage for the provided stream - """ - + """Builds an AirbyteStreamStatusTraceMessage for the provided stream""" now_millis = datetime.now().timestamp() * 1000.0 trace_message = AirbyteTraceMessage( diff --git a/airbyte_cdk/utils/traced_exception.py b/airbyte_cdk/utils/traced_exception.py index 11f60032..ec1e3de2 100644 --- a/airbyte_cdk/utils/traced_exception.py +++ b/airbyte_cdk/utils/traced_exception.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import time import traceback -from typing import Optional + +from orjson import orjson from airbyte_cdk.models import ( AirbyteConnectionStatus, @@ -18,24 +21,20 @@ ) from airbyte_cdk.models import Type as MessageType from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets -from orjson import orjson class AirbyteTracedException(Exception): - """ - An exception that should be emitted as an AirbyteTraceMessage - """ + """An exception that should be emitted as an AirbyteTraceMessage""" def __init__( self, - internal_message: Optional[str] = None, - message: Optional[str] = None, + internal_message: str | None = None, + message: str | None = None, failure_type: FailureType = FailureType.system_error, - exception: Optional[BaseException] = None, - stream_descriptor: Optional[StreamDescriptor] = None, + exception: BaseException | None = None, + stream_descriptor: StreamDescriptor | None = None, ): - """ - :param internal_message: the internal error that caused the failure + """:param internal_message: the internal error that caused the failure :param message: a user-friendly message that indicates the cause of the error :param failure_type: the type of error :param exception: the exception that caused the error, from which the stack trace should be retrieved @@ -49,10 +48,9 @@ def __init__( super().__init__(internal_message) def as_airbyte_message( - self, stream_descriptor: Optional[StreamDescriptor] = None + self, stream_descriptor: StreamDescriptor | None = None ) -> AirbyteMessage: - """ - Builds an AirbyteTraceMessage from the exception + """Builds an AirbyteTraceMessage from the exception :param stream_descriptor is deprecated, please use the stream_description in `__init__ or `from_exception`. If many stream_descriptors are defined, the one from `as_airbyte_message` will be discarded. @@ -79,7 +77,7 @@ def as_airbyte_message( return AirbyteMessage(type=MessageType.TRACE, trace=trace_message) - def as_connection_status_message(self) -> Optional[AirbyteMessage]: + def as_connection_status_message(self) -> AirbyteMessage | None: if self.failure_type == FailureType.config_error: return AirbyteMessage( type=MessageType.CONNECTION_STATUS, @@ -90,8 +88,7 @@ def as_connection_status_message(self) -> Optional[AirbyteMessage]: return None def emit_message(self) -> None: - """ - Prints the exception as an AirbyteTraceMessage. + """Prints the exception as an AirbyteTraceMessage. Note that this will be called automatically on uncaught exceptions when using the airbyte_cdk entrypoint. 
""" message = orjson.dumps(AirbyteMessageSerializer.dump(self.as_airbyte_message())).decode() @@ -102,12 +99,11 @@ def emit_message(self) -> None: def from_exception( cls, exc: BaseException, - stream_descriptor: Optional[StreamDescriptor] = None, + stream_descriptor: StreamDescriptor | None = None, *args, **kwargs, - ) -> "AirbyteTracedException": # type: ignore # ignoring because of args and kwargs - """ - Helper to create an AirbyteTracedException from an existing exception + ) -> AirbyteTracedException: # type: ignore # ignoring because of args and kwargs + """Helper to create an AirbyteTracedException from an existing exception :param exc: the exception that caused the error :param stream_descriptor: describe the stream from which the exception comes from """ @@ -120,10 +116,9 @@ def from_exception( ) # type: ignore # ignoring because of args and kwargs def as_sanitized_airbyte_message( - self, stream_descriptor: Optional[StreamDescriptor] = None + self, stream_descriptor: StreamDescriptor | None = None ) -> AirbyteMessage: - """ - Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body + """Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body :param stream_descriptor is deprecated, please use the stream_description in `__init__ or `from_exception`. If many stream_descriptors are defined, the one from `as_sanitized_airbyte_message` will be discarded. diff --git a/bin/generate_component_manifest_files.py b/bin/generate_component_manifest_files.py index 7e9c6835..39c76381 100755 --- a/bin/generate_component_manifest_files.py +++ b/bin/generate_component_manifest_files.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import sys from glob import glob @@ -7,6 +8,7 @@ import anyio import dagger + PYTHON_IMAGE = "python:3.10" LOCAL_YAML_DIR_PATH = "airbyte_cdk/sources/declarative" LOCAL_OUTPUT_DIR_PATH = "airbyte_cdk/sources/declarative/models" diff --git a/docs/generate.py b/docs/generate.py index 58589771..8f4a7444 100644 --- a/docs/generate.py +++ b/docs/generate.py @@ -24,7 +24,6 @@ def run() -> None: """Generate docs for all public modules in the Airbyte CDK and save them to docs/generated.""" - public_modules = [ "airbyte_cdk", ] @@ -49,7 +48,7 @@ def run() -> None: for file_name in files: if not file_name.endswith(".py"): continue - if file_name in ["py.typed"]: + if file_name == "py.typed": continue if file_name.startswith((".", "_")): continue diff --git a/reference_docs/_source/conf.py b/reference_docs/_source/conf.py index c661adb4..61c33095 100644 --- a/reference_docs/_source/conf.py +++ b/reference_docs/_source/conf.py @@ -14,9 +14,12 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +from __future__ import annotations + import os import sys + sys.path.insert(0, os.path.abspath("../..")) diff --git a/reference_docs/generate_rst_schema.py b/reference_docs/generate_rst_schema.py index 1a2268ee..bbd35f19 100755 --- a/reference_docs/generate_rst_schema.py +++ b/reference_docs/generate_rst_schema.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import sys from os import path -from typing import Any, Dict +from typing import Any from sphinx.cmd.quickstart import QuickstartRenderer from sphinx.ext.apidoc import get_parser, main, recurse_tree, write_file @@ -12,7 +13,7 @@ from sphinx.util import ensuredir -def write_master_file(templatedir: str, master_name: str, values: Dict, opts: Any): +def write_master_file(templatedir: str, master_name: str, values: dict, opts: Any): template = QuickstartRenderer(templatedir=templatedir) opts.destdir = opts.destdir[: opts.destdir.rfind("/")] write_file(master_name, template.render(f"{templatedir}/master_doc.rst_t", values), opts) @@ -30,8 +31,7 @@ def write_master_file(templatedir: str, master_name: str, values: Dict, opts: An # normalize opts if args.header is None: args.header = rootpath.split(path.sep)[-1] - if args.suffix.startswith("."): - args.suffix = args.suffix[1:] + args.suffix = args.suffix.removeprefix(".") if not path.isdir(rootpath): print(__(f"{rootpath} is not a directory."), file=sys.stderr) sys.exit(1) diff --git a/unit_tests/conftest.py b/unit_tests/conftest.py index 3a21552b..9b52873d 100644 --- a/unit_tests/conftest.py +++ b/unit_tests/conftest.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime @@ -8,7 +9,7 @@ import pytest -@pytest.fixture() +@pytest.fixture def mock_sleep(monkeypatch): with freezegun.freeze_time( datetime.datetime.now(), ignore=["_pytest.runner", "_pytest.terminal"] diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index 10bd4513..cd4eafc3 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import copy import dataclasses @@ -12,6 +13,10 @@ import pytest import requests +from orjson import orjson + +from unit_tests.connector_builder.utils import create_configured_catalog + from airbyte_cdk import connector_builder from airbyte_cdk.connector_builder.connector_builder_handler import ( DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, @@ -49,16 +54,15 @@ Level, StreamDescriptor, SyncMode, + Type, ) -from airbyte_cdk.models import Type from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers import SimpleRetrieverTestReadDecorator from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets, update_secrets -from orjson import orjson -from unit_tests.connector_builder.utils import create_configured_catalog + _stream_name = "stream_with_custom_requester" _stream_primary_key = "id" @@ -292,9 +296,7 @@ def invalid_config_file(tmp_path): def _mocked_send(self, request, **kwargs) -> requests.Response: - """ - Mocks the outbound send operation to provide faster and more reliable responses compared to actual API requests - """ + """Mocks the outbound send operation to provide faster and more reliable responses compared to actual API requests""" response = requests.Response() response.request = request response.status_code = 200 @@ -866,8 +868,7 @@ def _create_429_page_response(response_body): * 10, ) def test_read_source(mock_http_stream): - """ - This test sort of acts as an integration test for the connector builder. + """This test sort of acts as an integration test for the connector builder. Each slice has two pages The first page has two records @@ -1002,13 +1003,11 @@ def test_read_source_single_page_single_slice(mock_http_stream): ) @patch.object(requests.Session, "send", _mocked_send) def test_handle_read_external_requests(deployment_mode, url_base, expected_error): - """ - This test acts like an integration test for the connector builder when it receives Test Read requests. + """This test acts like an integration test for the connector builder when it receives Test Read requests. The scenario being tested is whether requests should be denied if they are done on an unsecure channel or are made to internal endpoints when running on Cloud or OSS deployments """ - limits = TestReadLimits(max_records=100, max_pages_per_slice=1, max_slices=1) catalog = ConfiguredAirbyteCatalog( @@ -1088,13 +1087,11 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error ) @patch.object(requests.Session, "send", _mocked_send) def test_handle_read_external_oauth_request(deployment_mode, token_url, expected_error): - """ - This test acts like an integration test for the connector builder when it receives Test Read requests. + """This test acts like an integration test for the connector builder when it receives Test Read requests. 
The scenario being tested is whether requests should be denied if they are done on an unsecure channel or are made to internal endpoints when running on Cloud or OSS deployments """ - limits = TestReadLimits(max_records=100, max_pages_per_slice=1, max_slices=1) catalog = ConfiguredAirbyteCatalog( diff --git a/unit_tests/connector_builder/test_message_grouper.py b/unit_tests/connector_builder/test_message_grouper.py index e95f7fcc..3284cf9d 100644 --- a/unit_tests/connector_builder/test_message_grouper.py +++ b/unit_tests/connector_builder/test_message_grouper.py @@ -1,12 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json -from typing import Any, Iterator, List, Mapping +from collections.abc import Iterator, Mapping +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from orjson import orjson + +from unit_tests.connector_builder.utils import create_configured_catalog + from airbyte_cdk.connector_builder.message_grouper import MessageGrouper from airbyte_cdk.connector_builder.models import ( HttpRequest, @@ -29,8 +35,7 @@ StreamDescriptor, ) from airbyte_cdk.models import Type as MessageType -from orjson import orjson -from unit_tests.connector_builder.utils import create_configured_catalog + _NO_PK = [[]] _NO_CURSOR_FIELD = [] @@ -253,9 +258,9 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: ), ] expected_logs = [ - LogMessage(**{"message": "log message before the request", "level": "INFO"}), - LogMessage(**{"message": "log message during the page", "level": "INFO"}), - LogMessage(**{"message": "log message after the response", "level": "INFO"}), + LogMessage(message="log message before the request", level="INFO"), + LogMessage(message="log message during the page", level="INFO"), + LogMessage(message="log message after the response", level="INFO"), ] mock_source = make_mock_source( @@ -704,7 +709,7 @@ def test_read_stream_returns_error_if_stream_does_not_exist() -> None: mock_source.read.side_effect = ValueError("error") mock_source.streams.return_value = [make_mock_stream()] - full_config: Mapping[str, Any] = {**CONFIG, **{"__injected_declarative_manifest": MANIFEST}} + full_config: Mapping[str, Any] = {**CONFIG, "__injected_declarative_manifest": MANIFEST} message_grouper = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) actual_response = message_grouper.get_message_groups( @@ -999,7 +1004,7 @@ def request_response_log_message( ) -def any_request_and_response_with_a_record() -> List[AirbyteMessage]: +def any_request_and_response_with_a_record() -> list[AirbyteMessage]: return [ request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"), record_message("hashiras", {"name": "Shinobu Kocho"}), diff --git a/unit_tests/connector_builder/utils.py b/unit_tests/connector_builder/utils.py index a94a0416..d0bb2fcb 100644 --- a/unit_tests/connector_builder/utils.py +++ b/unit_tests/connector_builder/utils.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteCatalogSerializer diff --git a/unit_tests/destinations/test_destination.py b/unit_tests/destinations/test_destination.py index ffe1fd37..2aed0a79 100644 --- a/unit_tests/destinations/test_destination.py +++ b/unit_tests/destinations/test_destination.py @@ -1,15 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import argparse import io import json +from collections.abc import Iterable, Mapping from os import PathLike -from typing import Any, Dict, Iterable, List, Mapping, Union +from typing import Any from unittest.mock import ANY import pytest +from orjson import orjson + from airbyte_cdk.destinations import Destination from airbyte_cdk.destinations import destination as destination_module from airbyte_cdk.models import ( @@ -29,7 +33,6 @@ SyncMode, Type, ) -from orjson import orjson @pytest.fixture(name="destination") @@ -53,7 +56,7 @@ class TestArgParsing: ], ) def test_successful_parse( - self, arg_list: List[str], expected_output: Mapping[str, Any], destination: Destination + self, arg_list: list[str], expected_output: Mapping[str, Any], destination: Destination ): parsed_args = vars(destination.parse_args(arg_list)) assert ( @@ -74,57 +77,53 @@ def test_successful_parse( (["check", "path"]), ], ) - def test_failed_parse(self, arg_list: List[str], destination: Destination): + def test_failed_parse(self, arg_list: list[str], destination: Destination): # We use BaseException because it encompasses SystemExit (raised by failed parsing) and other exceptions (raised by additional semantic # checks) with pytest.raises(BaseException): destination.parse_args(arg_list) -def _state(state: Dict[str, Any]) -> AirbyteStateMessage: +def _state(state: dict[str, Any]) -> AirbyteStateMessage: return AirbyteStateMessage(data=state) -def _record(stream: str, data: Dict[str, Any]) -> AirbyteRecordMessage: +def _record(stream: str, data: dict[str, Any]) -> AirbyteRecordMessage: return AirbyteRecordMessage(stream=stream, data=data, emitted_at=0) -def _spec(schema: Dict[str, Any]) -> ConnectorSpecification: +def _spec(schema: dict[str, Any]) -> ConnectorSpecification: return ConnectorSpecification(connectionSpecification=schema) -def write_file(path: PathLike, content: Union[str, Mapping]): +def write_file(path: PathLike, content: str | Mapping): content = json.dumps(content) if isinstance(content, Mapping) else content with open(path, "w") as f: f.write(content) def _wrapped( - msg: Union[ - AirbyteRecordMessage, - AirbyteStateMessage, - AirbyteCatalog, - ConnectorSpecification, - AirbyteConnectionStatus, - ], + msg: AirbyteRecordMessage + | AirbyteStateMessage + | AirbyteCatalog + | ConnectorSpecification + | AirbyteConnectionStatus, ) -> AirbyteMessage: if isinstance(msg, AirbyteRecordMessage): return AirbyteMessage(type=Type.RECORD, record=msg) - elif isinstance(msg, AirbyteStateMessage): + if isinstance(msg, AirbyteStateMessage): return AirbyteMessage(type=Type.STATE, state=msg) - elif isinstance(msg, AirbyteCatalog): + if isinstance(msg, AirbyteCatalog): return AirbyteMessage(type=Type.CATALOG, catalog=msg) - elif isinstance(msg, AirbyteConnectionStatus): + if isinstance(msg, AirbyteConnectionStatus): return AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=msg) - elif isinstance(msg, ConnectorSpecification): + if isinstance(msg, 
ConnectorSpecification): return AirbyteMessage(type=Type.SPEC, spec=msg) - else: - raise Exception(f"Invalid Airbyte Message: {msg}") + raise Exception(f"Invalid Airbyte Message: {msg}") class OrderedIterableMatcher(Iterable): - """ - A class whose purpose is to verify equality of one iterable object against another + """A class whose purpose is to verify equality of one iterable object against another in an ordered fashion """ @@ -269,7 +268,7 @@ def test_run_write(self, mocker, destination: Destination, tmp_path, monkeypatch "airbyte_cdk.destinations.destination.check_config_against_spec_or_exit" ) # mock input is a record followed by some state messages - mocked_input: List[AirbyteMessage] = [ + mocked_input: list[AirbyteMessage] = [ _wrapped(_record("s1", {"k1": "v1"})), *expected_write_result, ] diff --git a/unit_tests/destinations/vector_db_based/config_test.py b/unit_tests/destinations/vector_db_based/config_test.py index 0eeae37b..16f701fe 100644 --- a/unit_tests/destinations/vector_db_based/config_test.py +++ b/unit_tests/destinations/vector_db_based/config_test.py @@ -1,10 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - -from typing import Union +from __future__ import annotations import dpath +from pydantic.v1 import BaseModel, Field + from airbyte_cdk.destinations.vector_db_based.config import ( AzureOpenAIEmbeddingConfigModel, CohereEmbeddingConfigModel, @@ -14,7 +15,6 @@ ProcessingConfigModel, ) from airbyte_cdk.utils.spec_schema_transformations import resolve_refs -from pydantic.v1 import BaseModel, Field class IndexingModel(BaseModel): @@ -28,13 +28,13 @@ class IndexingModel(BaseModel): class ConfigModel(BaseModel): indexing: IndexingModel - embedding: Union[ - OpenAIEmbeddingConfigModel, - CohereEmbeddingConfigModel, - FakeEmbeddingConfigModel, - AzureOpenAIEmbeddingConfigModel, - OpenAICompatibleEmbeddingConfigModel, - ] = Field( + embedding: ( + OpenAIEmbeddingConfigModel + | CohereEmbeddingConfigModel + | FakeEmbeddingConfigModel + | AzureOpenAIEmbeddingConfigModel + | OpenAICompatibleEmbeddingConfigModel + ) = Field( ..., title="Embedding", description="Embedding configuration", @@ -56,12 +56,12 @@ class Config: @staticmethod def remove_discriminator(schema: dict) -> None: - """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" + """Pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" dpath.delete(schema, "properties/**/discriminator") @classmethod def schema(cls): - """we're overriding the schema classmethod to enable some post-processing""" + """We're overriding the schema classmethod to enable some post-processing""" schema = super().schema() schema = resolve_refs(schema) cls.remove_discriminator(schema) diff --git a/unit_tests/destinations/vector_db_based/document_processor_test.py b/unit_tests/destinations/vector_db_based/document_processor_test.py index f427f42d..e7f92012 100644 --- a/unit_tests/destinations/vector_db_based/document_processor_test.py +++ b/unit_tests/destinations/vector_db_based/document_processor_test.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any from unittest.mock import MagicMock import pytest + from airbyte_cdk.destinations.vector_db_based.config import ( CodeSplitterConfigModel, FieldNameMappingConfigModel, @@ -500,7 +503,7 @@ def test_text_splitter_check(label, split_config, has_error_message): ], ) def test_rename_metadata_fields( - mappings: Optional[List[FieldNameMappingConfigModel]], + mappings: list[FieldNameMappingConfigModel] | None, fields: Mapping[str, Any], expected_chunk_metadata: Mapping[str, Any], ): @@ -555,7 +558,7 @@ def test_rename_metadata_fields( def test_process_multiple_chunks_with_dedupe_mode( primary_key_value: Mapping[str, Any], stringified_primary_key: str, - primary_key: List[List[str]], + primary_key: list[list[str]], ): processor = initialize_processor() diff --git a/unit_tests/destinations/vector_db_based/embedder_test.py b/unit_tests/destinations/vector_db_based/embedder_test.py index dc0f5378..0a6eebe7 100644 --- a/unit_tests/destinations/vector_db_based/embedder_test.py +++ b/unit_tests/destinations/vector_db_based/embedder_test.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock, call import pytest + from airbyte_cdk.destinations.vector_db_based.config import ( AzureOpenAIEmbeddingConfigModel, CohereEmbeddingConfigModel, @@ -33,25 +35,23 @@ ( ( OpenAIEmbedder, - [OpenAIEmbeddingConfigModel(**{"mode": "openai", "openai_key": "abc"}), 1000], + [OpenAIEmbeddingConfigModel(mode="openai", openai_key="abc"), 1000], OPEN_AI_VECTOR_SIZE, ), ( CohereEmbedder, - [CohereEmbeddingConfigModel(**{"mode": "cohere", "cohere_key": "abc"})], + [CohereEmbeddingConfigModel(mode="cohere", cohere_key="abc")], COHERE_VECTOR_SIZE, ), - (FakeEmbedder, [FakeEmbeddingConfigModel(**{"mode": "fake"})], OPEN_AI_VECTOR_SIZE), + (FakeEmbedder, [FakeEmbeddingConfigModel(mode="fake")], OPEN_AI_VECTOR_SIZE), ( AzureOpenAIEmbedder, [ AzureOpenAIEmbeddingConfigModel( - **{ - "mode": "azure_openai", - "openai_key": "abc", - "api_base": "https://my-resource.openai.azure.com", - "deployment": "my-deployment", - } + mode="azure_openai", + openai_key="abc", + api_base="https://my-resource.openai.azure.com", + deployment="my-deployment", ), 1000, ], @@ -61,13 +61,11 @@ OpenAICompatibleEmbedder, [ OpenAICompatibleEmbeddingConfigModel( - **{ - "mode": "openai_compatible", - "api_key": "abc", - "base_url": "https://my-service.com", - "model_name": "text-embedding-ada-002", - "dimensions": 50, - } + mode="openai_compatible", + api_key="abc", + base_url="https://my-service.com", + model_name="text-embedding-ada-002", + dimensions=50, ) ], 50, @@ -132,7 +130,7 @@ def test_from_field_embedder(field_name, dimensions, metadata, expected_embeddin def test_openai_chunking(): - config = OpenAIEmbeddingConfigModel(**{"mode": "openai", "openai_key": "abc"}) + config = OpenAIEmbeddingConfigModel(mode="openai", openai_key="abc") embedder = OpenAIEmbedder(config, 150) mock_embedding_instance = MagicMock() embedder.embeddings = mock_embedding_instance diff --git a/unit_tests/destinations/vector_db_based/writer_test.py b/unit_tests/destinations/vector_db_based/writer_test.py index b39ce8d3..8a2afd42 100644 --- a/unit_tests/destinations/vector_db_based/writer_test.py +++ b/unit_tests/destinations/vector_db_based/writer_test.py @@ -1,11 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import Optional from unittest.mock import ANY, MagicMock, call import pytest + from airbyte_cdk.destinations.vector_db_based import ProcessingConfigModel, Writer from airbyte_cdk.models import ( AirbyteLogMessage, @@ -20,7 +21,7 @@ def _generate_record_message( - index: int, stream: str = "example_stream", namespace: Optional[str] = None + index: int, stream: str = "example_stream", namespace: str | None = None ): return AirbyteMessage( type=Type.RECORD, @@ -36,7 +37,7 @@ def _generate_record_message( BATCH_SIZE = 32 -def generate_stream(name: str = "example_stream", namespace: Optional[str] = None): +def generate_stream(name: str = "example_stream", namespace: str | None = None): return { "stream": { "name": name, @@ -66,9 +67,7 @@ def generate_mock_embedder(): @pytest.mark.parametrize("omit_raw_text", [True, False]) def test_write(omit_raw_text: bool): - """ - Basic test for the write method, batcher and document processor. - """ + """Basic test for the write method, batcher and document processor.""" config_model = ProcessingConfigModel( chunk_overlap=0, chunk_size=1000, metadata_fields=None, text_fields=["column_name"] ) @@ -132,8 +131,7 @@ def test_write(omit_raw_text: bool): def test_write_stream_namespace_split(): - """ - Test separate handling of streams and namespaces in the writer + """Test separate handling of streams and namespaces in the writer generate BATCH_SIZE - 10 records for example_stream, 5 records for example_stream with namespace abc and 10 records for example_stream2 messages are flushed after 32 records or after a state message, so this will trigger 4 calls to the indexer: diff --git a/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py b/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py index 6593416a..f640e225 100644 --- a/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py +++ b/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging -from typing import Any, List, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any from unittest.mock import Mock import freezegun import pytest + from airbyte_cdk.models import ( AirbyteMessage, AirbyteRecordMessage, @@ -44,10 +47,10 @@ def __init__( def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: raise NotImplementedError - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: return [ self.convert_to_concurrent_stream(self._logger, s, Mock()) if is_concurrent else s for s, is_concurrent in self._streams_to_is_concurrent.items() @@ -138,7 +141,7 @@ def _mock_stream(name: str, data=[], available: bool = True): return s -def _configured_catalog(streams: List[Stream]): +def _configured_catalog(streams: list[Stream]): return ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( @@ -155,9 +158,7 @@ def _configured_catalog(streams: List[Stream]): def test_read_nonexistent_concurrent_stream_emit_incomplete_stream_status( mocker, remove_stack_trace, as_stream_status, raise_exception_on_missing_stream ): - """ - Tests that attempting to sync a stream which the source does not return from the `streams` method emits incomplete stream status. 
- """ + """Tests that attempting to sync a stream which the source does not return from the `streams` method emits incomplete stream status.""" logger = Mock() s1 = _mock_stream("s1", []) diff --git a/unit_tests/sources/conftest.py b/unit_tests/sources/conftest.py index d20d763b..dd0fc772 100644 --- a/unit_tests/sources/conftest.py +++ b/unit_tests/sources/conftest.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import pytest + from airbyte_cdk.models import ( AirbyteMessage, AirbyteStreamStatus, @@ -19,9 +21,7 @@ @pytest.fixture def remove_stack_trace(): def _remove_stack_trace(message: AirbyteMessage) -> AirbyteMessage: - """ - Helper method that removes the stack trace from Airbyte trace messages to make asserting against expected records easier - """ + """Helper method that removes the stack trace from Airbyte trace messages to make asserting against expected records easier""" if message.trace and message.trace.error and message.trace.error.stack_trace: message.trace.error.stack_trace = None return message diff --git a/unit_tests/sources/declarative/async_job/test_integration.py b/unit_tests/sources/declarative/async_job/test_integration.py index be078488..de2303b5 100644 --- a/unit_tests/sources/declarative/async_job/test_integration.py +++ b/unit_tests/sources/declarative/async_job/test_integration.py @@ -1,8 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. - +from __future__ import annotations import logging -from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple +from collections.abc import Iterable, Mapping +from typing import Any from unittest import TestCase, mock from airbyte_cdk import ( @@ -28,6 +29,7 @@ from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder from airbyte_cdk.test.entrypoint_wrapper import read + _A_STREAM_NAME = "a_stream_name" _EXTRACTOR_NOT_USED: RecordExtractor = None # type: ignore # the extractor should not be used. 
If it is the case, there is an issue that needs fixing _NO_LIMIT = 10000 @@ -37,7 +39,7 @@ class MockAsyncJobRepository(AsyncJobRepository): def start(self, stream_slice: StreamSlice) -> AsyncJob: return AsyncJob("a_job_id", StreamSlice(partition={}, cursor_slice={})) - def update_jobs_status(self, jobs: Set[AsyncJob]) -> None: + def update_jobs_status(self, jobs: set[AsyncJob]) -> None: for job in jobs: job.update_status(AsyncJobStatus.COMPLETED) @@ -52,19 +54,19 @@ def delete(self, job: AsyncJob) -> None: class MockSource(AbstractSource): - def __init__(self, stream_slicer: Optional[StreamSlicer] = None) -> None: + def __init__(self, stream_slicer: StreamSlicer | None = None) -> None: self._stream_slicer = SinglePartitionRouter({}) if stream_slicer is None else stream_slicer self._message_repository = NoopMessageRepository() def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: return True, None def spec(self, logger: logging.Logger) -> ConnectorSpecification: return ConnectorSpecification(connectionSpecification={}) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: noop_record_selector = RecordSelector( extractor=_EXTRACTOR_NOT_USED, config={}, @@ -119,9 +121,7 @@ def test_when_read_then_return_records_from_repository(self) -> None: assert len(output.records) == 1 def test_when_read_then_call_stream_slices_only_once(self) -> None: - """ - As generating stream slices is very expensive, we want to ensure that during a read, it is only called once. - """ + """As generating stream slices is very expensive, we want to ensure that during a read, it is only called once.""" output = read( self._source, self._CONFIG, diff --git a/unit_tests/sources/declarative/async_job/test_job.py b/unit_tests/sources/declarative/async_job/test_job.py index 6399433e..4d735631 100644 --- a/unit_tests/sources/declarative/async_job/test_job.py +++ b/unit_tests/sources/declarative/async_job/test_job.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import time from datetime import timedelta @@ -8,6 +9,7 @@ from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.declarative.types import StreamSlice + _AN_API_JOB_ID = "an api job id" _ANY_STREAM_SLICE = StreamSlice(partition={}, cursor_slice={}) _A_VERY_BIG_TIMEOUT = timedelta(days=999999999) @@ -25,8 +27,7 @@ def test_given_timer_is_out_when_status_then_return_timed_out(self) -> None: assert job.status() == AsyncJobStatus.TIMED_OUT def test_given_status_is_terminal_when_update_status_then_stop_timer(self) -> None: - """ - This test will become important once we will print stats associated with jobs. As for now, we stop the timer but do not return any + """This test will become important once we will print stats associated with jobs. As for now, we stop the timer but do not return any metrics regarding the timer so it is not useful. """ pass diff --git a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py index af8e84e7..5ff46572 100644 --- a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py +++ b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py @@ -1,14 +1,16 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations import logging import sys import threading import time -from typing import Callable, List, Mapping, Optional, Set, Tuple +from collections.abc import Callable, Mapping from unittest import TestCase, mock from unittest.mock import MagicMock, Mock, call import pytest + from airbyte_cdk import AirbyteTracedException, StreamSlice from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.async_job.job import AsyncJob, AsyncJobStatus @@ -21,6 +23,7 @@ from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.http.http_client import MessageRepresentationAirbyteTracedErrors + _ANY_STREAM_SLICE = Mock() _A_STREAM_SLICE = Mock() _ANOTHER_STREAM_SLICE = Mock() @@ -62,11 +65,11 @@ def test_given_only_completed_jobs_when_status_then_return_running(self) -> None def _status_update_per_jobs( - status_update_per_jobs: Mapping[AsyncJob, List[AsyncJobStatus]], + status_update_per_jobs: Mapping[AsyncJob, list[AsyncJobStatus]], ) -> Callable[[set[AsyncJob]], None]: - status_index_by_job = {job: 0 for job in status_update_per_jobs.keys()} + status_index_by_job = dict.fromkeys(status_update_per_jobs.keys(), 0) - def _update_status(jobs: Set[AsyncJob]) -> None: + def _update_status(jobs: set[AsyncJob]) -> None: for job in jobs: status_index = status_index_by_job[job] job.update_status(status_update_per_jobs[job][status_index]) @@ -182,9 +185,9 @@ def test_when_fetch_records_then_yield_records_from_each_job(self) -> None: assert self._job_repository.delete.mock_calls == [call(first_job), call(second_job)] def _orchestrator( - self, slices: List[StreamSlice], job_tracker: Optional[JobTracker] = None + self, slices: list[StreamSlice], job_tracker: JobTracker | None = None ) -> AsyncJobOrchestrator: - job_tracker = job_tracker if job_tracker else JobTracker(_NO_JOB_LIMIT) + job_tracker = job_tracker or JobTracker(_NO_JOB_LIMIT) return AsyncJobOrchestrator( self._job_repository, slices, job_tracker, self._message_repository ) @@ -238,9 +241,7 @@ def test_given_exception_to_break_when_start_job_and_raise_this_exception_and_ab def test_given_traced_config_error_when_start_job_and_raise_this_exception_and_abort_jobs( self, ) -> None: - """ - Since this is a config error, we assume the other jobs will fail for the same reasons. - """ + """Since this is a config error, we assume the other jobs will fail for the same reasons.""" job_tracker = JobTracker(1) self._job_repository.start.side_effect = MessageRepresentationAirbyteTracedErrors( "Can't create job", failure_type=FailureType.config_error @@ -263,9 +264,7 @@ def test_given_traced_config_error_when_start_job_and_raise_this_exception_and_a def test_given_exception_on_single_job_when_create_and_get_completed_partitions_then_return( self, mock_sleep: MagicMock ) -> None: - """ - We added this test because the initial logic of breaking the main loop we implemented (when `self._has_started_a_job and self._running_partitions`) was not enough in the case where there was only one slice and it would fail to start. 
- """ + """We added this test because the initial logic of breaking the main loop we implemented (when `self._has_started_a_job and self._running_partitions`) was not enough in the case where there was only one slice and it would fail to start.""" orchestrator = self._orchestrator([_A_STREAM_SLICE]) self._job_repository.start.side_effect = ValueError @@ -365,7 +364,7 @@ def _an_async_job(self, job_id: str, stream_slice: StreamSlice) -> AsyncJob: def _accumulate_create_and_get_completed_partitions( self, orchestrator: AsyncJobOrchestrator - ) -> Tuple[List[AsyncPartition], Optional[Exception]]: + ) -> tuple[list[AsyncPartition], Exception | None]: result = [] try: for i in orchestrator.create_and_get_completed_partitions(): diff --git a/unit_tests/sources/declarative/async_job/test_job_tracker.py b/unit_tests/sources/declarative/async_job/test_job_tracker.py index 6d09df16..fd865321 100644 --- a/unit_tests/sources/declarative/async_job/test_job_tracker.py +++ b/unit_tests/sources/declarative/async_job/test_job_tracker.py @@ -1,14 +1,16 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations -from typing import List from unittest import TestCase import pytest + from airbyte_cdk.sources.declarative.async_job.job_tracker import ( ConcurrentJobLimitReached, JobTracker, ) + _LIMIT = 3 @@ -36,5 +38,5 @@ def test_given_limit_reached_when_add_job_then_limit_is_still_reached(self) -> N with pytest.raises(ConcurrentJobLimitReached): self._tracker.try_to_get_intent() - def _reach_limit(self) -> List[str]: + def _reach_limit(self) -> list[str]: return [self._tracker.try_to_get_intent() for i in range(_LIMIT)] diff --git a/unit_tests/sources/declarative/auth/test_jwt.py b/unit_tests/sources/declarative/auth/test_jwt.py index a26042f7..ee59ae60 100644 --- a/unit_tests/sources/declarative/auth/test_jwt.py +++ b/unit_tests/sources/declarative/auth/test_jwt.py @@ -1,6 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import base64 import logging from datetime import datetime @@ -8,15 +10,15 @@ import freezegun import jwt import pytest + from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator + LOGGER = logging.getLogger(__name__) class TestJwtAuthenticator: - """ - Test class for JWT Authenticator. - """ + """Test class for JWT Authenticator.""" @pytest.mark.parametrize( "algorithm, kid, typ, cty, additional_jwt_headers, expected", @@ -110,7 +112,7 @@ def test_given_overriden_reserverd_properties_get_jwt_payload_throws_error(self) @pytest.mark.parametrize( "base64_encode_secret_key, secret_key, expected", [ - (True, "test", base64.b64encode("test".encode()).decode()), + (True, "test", base64.b64encode(b"test").decode()), (False, "test", "test"), ], ) diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index 4cdfad2f..f440f4c2 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -1,6 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import base64 import logging from unittest.mock import Mock @@ -9,9 +11,11 @@ import pendulum import pytest import requests +from requests import Response + from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets -from requests import Response + LOGGER = logging.getLogger(__name__) @@ -30,14 +34,10 @@ class TestOauth2Authenticator: - """ - Test class for OAuth2Authenticator. - """ + """Test class for OAuth2Authenticator.""" def test_refresh_request_body(self): - """ - Request body should match given configuration. - """ + """Request body should match given configuration.""" scopes = ["scope1", "scope2"] oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -108,9 +108,7 @@ def test_refresh_with_decode_config_params(self): assert body == expected def test_refresh_without_refresh_token(self): - """ - Should work fine for grant_type client_credentials. - """ + """Should work fine for grant_type client_credentials.""" oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", client_id="{{ config['client_id'] }}", @@ -129,9 +127,7 @@ def test_refresh_without_refresh_token(self): assert body == expected def test_error_on_refresh_token_grant_without_refresh_token(self): - """ - Should throw an error if grant_type refresh_token is configured without refresh_token. - """ + """Should throw an error if grant_type refresh_token is configured without refresh_token.""" with pytest.raises(ValueError): DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -166,7 +162,7 @@ def test_refresh_access_token(self, mocker): mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert token == ("access_token", 1000) filtered = filter_secrets("access_token") assert filtered == "****" @@ -270,7 +266,7 @@ def test_refresh_access_token_expire_format( ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token = oauth.get_access_token() - assert "access_token" == token + assert token == "access_token" assert oauth.get_token_expiry_date() == pendulum.parse(next_day) assert message_repository.log_message.call_count == 1 @@ -323,7 +319,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next oauth.get_access_token() else: token = oauth.get_access_token() - assert "access_token" == token + assert token == "access_token" assert oauth.get_token_expiry_date() == pendulum.parse(next_day) def test_error_handling(self, mocker): diff --git a/unit_tests/sources/declarative/auth/test_selective_authenticator.py b/unit_tests/sources/declarative/auth/test_selective_authenticator.py index 55b0a1ed..2e935d59 100644 --- a/unit_tests/sources/declarative/auth/test_selective_authenticator.py +++ b/unit_tests/sources/declarative/auth/test_selective_authenticator.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.auth.selective_authenticator import SelectiveAuthenticator diff --git a/unit_tests/sources/declarative/auth/test_session_token_auth.py b/unit_tests/sources/declarative/auth/test_session_token_auth.py index eda2f36b..abe645e5 100644 --- a/unit_tests/sources/declarative/auth/test_session_token_auth.py +++ b/unit_tests/sources/declarative/auth/test_session_token_auth.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest +from requests.exceptions import HTTPError + from airbyte_cdk.sources.declarative.auth.token import ( LegacySessionTokenAuthenticator, get_new_session_token, ) -from requests.exceptions import HTTPError + parameters = {"hello": "world"} instance_api_url = "https://airbyte.metabaseapp.com/api/" diff --git a/unit_tests/sources/declarative/auth/test_token_auth.py b/unit_tests/sources/declarative/auth/test_token_auth.py index 64b181c4..0f2fd187 100644 --- a/unit_tests/sources/declarative/auth/test_token_auth.py +++ b/unit_tests/sources/declarative/auth/test_token_auth.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import pytest import requests +from requests import Response + from airbyte_cdk.sources.declarative.auth.token import ( ApiKeyAuthenticator, BasicHttpAuthenticator, @@ -16,7 +19,7 @@ RequestOption, RequestOptionType, ) -from requests import Response + LOGGER = logging.getLogger(__name__) @@ -34,9 +37,7 @@ ], ) def test_bearer_token_authenticator(test_name, token, expected_header_value): - """ - Should match passed in token, no matter how many times token is retrieved. - """ + """Should match passed in token, no matter how many times token is retrieved.""" token_provider = InterpolatedStringTokenProvider( config=config, api_token=token, parameters=parameters ) @@ -48,9 +49,9 @@ def test_bearer_token_authenticator(test_name, token, expected_header_value): prepared_request.headers = {} token_auth(prepared_request) - assert {"Authorization": expected_header_value} == prepared_request.headers - assert {"Authorization": expected_header_value} == header1 - assert {"Authorization": expected_header_value} == header2 + assert prepared_request.headers == {"Authorization": expected_header_value} + assert header1 == {"Authorization": expected_header_value} + assert header2 == {"Authorization": expected_header_value} @pytest.mark.parametrize( @@ -72,9 +73,7 @@ def test_bearer_token_authenticator(test_name, token, expected_header_value): ], ) def test_basic_authenticator(test_name, username, password, expected_header_value): - """ - Should match passed in token, no matter how many times token is retrieved. 
- """ + """Should match passed in token, no matter how many times token is retrieved.""" token_auth = BasicHttpAuthenticator( username=username, password=password, config=config, parameters=parameters ) @@ -85,9 +84,9 @@ def test_basic_authenticator(test_name, username, password, expected_header_valu prepared_request.headers = {} token_auth(prepared_request) - assert {"Authorization": expected_header_value} == prepared_request.headers - assert {"Authorization": expected_header_value} == header1 - assert {"Authorization": expected_header_value} == header2 + assert prepared_request.headers == {"Authorization": expected_header_value} + assert header1 == {"Authorization": expected_header_value} + assert header2 == {"Authorization": expected_header_value} @pytest.mark.parametrize( @@ -111,9 +110,7 @@ def test_basic_authenticator(test_name, username, password, expected_header_valu ], ) def test_api_key_authenticator(test_name, header, token, expected_header, expected_header_value): - """ - Should match passed in token, no matter how many times token is retrieved. - """ + """Should match passed in token, no matter how many times token is retrieved.""" token_provider = InterpolatedStringTokenProvider( config=config, api_token=token, parameters=parameters ) @@ -132,9 +129,9 @@ def test_api_key_authenticator(test_name, header, token, expected_header, expect prepared_request.headers = {} token_auth(prepared_request) - assert {expected_header: expected_header_value} == prepared_request.headers - assert {expected_header: expected_header_value} == header1 - assert {expected_header: expected_header_value} == header2 + assert prepared_request.headers == {expected_header: expected_header_value} + assert header1 == {expected_header: expected_header_value} + assert header2 == {expected_header: expected_header_value} @pytest.mark.parametrize( @@ -232,9 +229,7 @@ def test_api_key_authenticator_inject( inject_type, validation_fn, ): - """ - Should match passed in token, no matter how many times token is retrieved. - """ + """Should match passed in token, no matter how many times token is retrieved.""" token_provider = InterpolatedStringTokenProvider( config=config, api_token=token, parameters=parameters ) @@ -246,4 +241,4 @@ def test_api_key_authenticator_inject( config=config, parameters=parameters, ) - assert {expected_field_name: expected_field_value} == getattr(token_auth, validation_fn)() + assert getattr(token_auth, validation_fn)() == {expected_field_name: expected_field_value} diff --git a/unit_tests/sources/declarative/auth/test_token_provider.py b/unit_tests/sources/declarative/auth/test_token_provider.py index 684dfbf7..f68bf18d 100644 --- a/unit_tests/sources/declarative/auth/test_token_provider.py +++ b/unit_tests/sources/declarative/auth/test_token_provider.py @@ -1,17 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from unittest.mock import MagicMock import pendulum import pytest +from isodate import parse_duration + from airbyte_cdk.sources.declarative.auth.token_provider import ( InterpolatedStringTokenProvider, SessionTokenProvider, ) from airbyte_cdk.sources.declarative.exceptions import ReadException -from isodate import parse_duration def create_session_token_provider(): diff --git a/unit_tests/sources/declarative/checks/test_check_stream.py b/unit_tests/sources/declarative/checks/test_check_stream.py index aee429c8..1febe645 100644 --- a/unit_tests/sources/declarative/checks/test_check_stream.py +++ b/unit_tests/sources/declarative/checks/test_check_stream.py @@ -1,16 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from unittest.mock import MagicMock import pytest import requests + from airbyte_cdk.sources.declarative.checks.check_stream import CheckStream from airbyte_cdk.sources.streams.http import HttpStream + logger = logging.getLogger("test") config = dict() @@ -125,7 +129,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.resp_counter = 1 - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def path(self, **kwargs) -> str: diff --git a/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py b/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py index 5858b680..331b8ec4 100644 --- a/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py +++ b/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py @@ -1,8 +1,11 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations -from typing import Any, Mapping, Optional, Type, Union +from collections.abc import Mapping +from typing import Any import pytest + from airbyte_cdk.sources.declarative.concurrency_level import ConcurrencyLevel @@ -32,7 +35,7 @@ ], ) def test_stream_slices( - default_concurrency: Union[int, str], max_concurrency: int, expected_concurrency: int + default_concurrency: int | str, max_concurrency: int, expected_concurrency: int ) -> None: config = {"num_workers": 50} concurrency_level = ConcurrencyLevel( @@ -62,8 +65,8 @@ def test_stream_slices( ) def test_default_concurrency_input_types_and_errors( config: Mapping[str, Any], - expected_concurrency: Optional[int], - expected_error: Optional[Type[Exception]], + expected_concurrency: int | None, + expected_error: type[Exception] | None, ) -> None: concurrency_level = ConcurrencyLevel( default_concurrency="{{ config['num_workers'] or 30 }}", diff --git a/unit_tests/sources/declarative/datetime/test_datetime_parser.py b/unit_tests/sources/declarative/datetime/test_datetime_parser.py index 6cbe59c7..e887cf98 100644 --- a/unit_tests/sources/declarative/datetime/test_datetime_parser.py +++ b/unit_tests/sources/declarative/datetime/test_datetime_parser.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import datetime import pytest + from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser diff --git a/unit_tests/sources/declarative/datetime/test_min_max_datetime.py b/unit_tests/sources/declarative/datetime/test_min_max_datetime.py index 848d673b..8dfd48b8 100644 --- a/unit_tests/sources/declarative/datetime/test_min_max_datetime.py +++ b/unit_tests/sources/declarative/datetime/test_min_max_datetime.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import pytest + from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString + date_format = "%Y-%m-%dT%H:%M:%S.%f%z" old_date = "2021-01-01T20:12:19.597854Z" @@ -91,7 +94,7 @@ def test_min_max_datetime(test_name, date, min_date, max_date, expected_date): min_max_date = MinMaxDatetime( datetime=date, min_datetime=min_date, max_datetime=max_date, parameters=parameters ) - actual_date = min_max_date.get_datetime(config, **{"stream_state": stream_state}) + actual_date = min_max_date.get_datetime(config, stream_state=stream_state) assert actual_date == datetime.datetime.strptime(expected_date, date_format) @@ -107,7 +110,7 @@ def test_custom_datetime_format(): max_datetime="{{ stream_state['newer'] }}", parameters={}, ) - actual_date = min_max_date.get_datetime(config, **{"stream_state": stream_state}) + actual_date = min_max_date.get_datetime(config, stream_state=stream_state) assert actual_date == datetime.datetime.strptime( "2022-01-01T20:12:19", "%Y-%m-%dT%H:%M:%S" @@ -125,7 +128,7 @@ def test_format_is_a_number(): max_datetime="{{ stream_state['newer'] }}", parameters={}, ) - actual_date = min_max_date.get_datetime(config, **{"stream_state": stream_state}) + actual_date = min_max_date.get_datetime(config, stream_state=stream_state) assert actual_date == datetime.datetime.strptime("20220101", "%Y%m%d").replace( tzinfo=datetime.timezone.utc diff --git a/unit_tests/sources/declarative/decoders/test_json_decoder.py b/unit_tests/sources/declarative/decoders/test_json_decoder.py index 861b6e27..4089eecb 100644 --- a/unit_tests/sources/declarative/decoders/test_json_decoder.py +++ b/unit_tests/sources/declarative/decoders/test_json_decoder.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import json import os import pytest import requests + from airbyte_cdk import YamlDeclarativeSource from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder, JsonlDecoder diff --git a/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py b/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py index 022482e7..213a30e7 100644 --- a/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py +++ b/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py @@ -1,8 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import pytest import requests + from airbyte_cdk.sources.declarative.decoders import JsonDecoder, PaginationDecoderDecorator diff --git a/unit_tests/sources/declarative/decoders/test_xml_decoder.py b/unit_tests/sources/declarative/decoders/test_xml_decoder.py index c6295cff..6b3c16ec 100644 --- a/unit_tests/sources/declarative/decoders/test_xml_decoder.py +++ b/unit_tests/sources/declarative/decoders/test_xml_decoder.py @@ -1,8 +1,11 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import pytest import requests + from airbyte_cdk.sources.declarative.decoders import XmlDecoder diff --git a/unit_tests/sources/declarative/external_component.py b/unit_tests/sources/declarative/external_component.py index d9f0ca8c..6d609f97 100644 --- a/unit_tests/sources/declarative/external_component.py +++ b/unit_tests/sources/declarative/external_component.py @@ -1,13 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from airbyte_cdk.sources.declarative.requesters import HttpRequester class SampleCustomComponent(HttpRequester): - """ - A test class used to validate manifests that rely on custom defined Python components - """ + """A test class used to validate manifests that rely on custom defined Python components""" pass diff --git a/unit_tests/sources/declarative/extractors/test_dpath_extractor.py b/unit_tests/sources/declarative/extractors/test_dpath_extractor.py index c5c40dd2..02205a89 100644 --- a/unit_tests/sources/declarative/extractors/test_dpath_extractor.py +++ b/unit_tests/sources/declarative/extractors/test_dpath_extractor.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import io import json -from typing import Dict, List, Union import pytest import requests + from airbyte_cdk import Decoder from airbyte_cdk.sources.declarative.decoders.json_decoder import ( IterableDecoder, @@ -15,6 +17,7 @@ ) from airbyte_cdk.sources.declarative.extractors.dpath_extractor import DpathExtractor + config = {"field": "record_array"} parameters = {"parameters_field": "record_array"} @@ -23,7 +26,7 @@ decoder_iterable = IterableDecoder(parameters={}) -def create_response(body: Union[Dict, bytes]): +def create_response(body: dict | bytes): response = requests.Response() response.raw = io.BytesIO(body if isinstance(body, bytes) else json.dumps(body).encode("utf-8")) return response @@ -111,7 +114,7 @@ def create_response(body: Union[Dict, bytes]): "test_extract_from_string_per_line_iterable", ], ) -def test_dpath_extractor(field_path: List, decoder: Decoder, body, expected_records: List): +def test_dpath_extractor(field_path: list, decoder: Decoder, body, expected_records: list): extractor = DpathExtractor( field_path=field_path, config=config, decoder=decoder, parameters=parameters ) diff --git a/unit_tests/sources/declarative/extractors/test_record_filter.py b/unit_tests/sources/declarative/extractors/test_record_filter.py index c4824c64..c1a9e9ea 100644 --- a/unit_tests/sources/declarative/extractors/test_record_filter.py +++ b/unit_tests/sources/declarative/extractors/test_record_filter.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# -from typing import List, Mapping, Optional +from __future__ import annotations + +from collections.abc import Mapping import pytest + from airbyte_cdk.sources.declarative.datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.extractors.record_filter import ( ClientSideIncrementalRecordFilterDecorator, @@ -24,6 +27,7 @@ from airbyte_cdk.sources.declarative.partition_routers import SubstreamPartitionRouter from airbyte_cdk.sources.declarative.types import StreamSlice + DATE_FORMAT = "%Y-%m-%d" RECORDS_TO_FILTER_DATE_FORMAT = [ {"id": 1, "created_at": "2020-01-03"}, @@ -110,7 +114,7 @@ ], ) def test_record_filter( - filter_template: str, records: List[Mapping], expected_records: List[Mapping] + filter_template: str, records: list[Mapping], expected_records: list[Mapping] ): config = {"response_override": "stop_if_you_see_me"} parameters = {"created_at": "06-07-21"} @@ -264,11 +268,11 @@ def test_record_filter( ) def test_client_side_record_filter_decorator_no_parent_stream( datetime_format: str, - stream_state: Optional[Mapping], + stream_state: Mapping | None, record_filter_expression: str, - end_datetime: Optional[str], - records_to_filter: List[Mapping], - expected_record_ids: List[int], + end_datetime: str | None, + records_to_filter: list[Mapping], + expected_record_ids: list[int], ): date_time_based_cursor = DatetimeBasedCursor( start_datetime=MinMaxDatetime( @@ -367,7 +371,7 @@ def test_client_side_record_filter_decorator_no_parent_stream( ], ) def test_client_side_record_filter_decorator_with_cursor_types( - stream_state: Optional[Mapping], cursor_type: str, expected_record_ids: List[int] + stream_state: Mapping | None, cursor_type: str, expected_record_ids: list[int] ): def date_time_based_cursor_factory() -> DatetimeBasedCursor: return DatetimeBasedCursor( diff --git a/unit_tests/sources/declarative/extractors/test_record_selector.py b/unit_tests/sources/declarative/extractors/test_record_selector.py index a83586f7..53852726 100644 --- a/unit_tests/sources/declarative/extractors/test_record_selector.py +++ b/unit_tests/sources/declarative/extractors/test_record_selector.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json from unittest.mock import Mock, call import pytest import requests + from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder from airbyte_cdk.sources.declarative.extractors.dpath_extractor import DpathExtractor from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter diff --git a/unit_tests/sources/declarative/extractors/test_response_to_file_extractor.py b/unit_tests/sources/declarative/extractors/test_response_to_file_extractor.py index 98251df6..62a2adf5 100644 --- a/unit_tests/sources/declarative/extractors/test_response_to_file_extractor.py +++ b/unit_tests/sources/declarative/extractors/test_response_to_file_extractor.py @@ -1,4 +1,6 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations + import csv import os from io import BytesIO @@ -8,6 +10,7 @@ import pytest import requests import requests_mock + from airbyte_cdk.sources.declarative.extractors import ResponseToFileExtractor diff --git a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py index 7b651e25..ff598957 100644 --- a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import unittest import pytest + from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -15,6 +17,7 @@ ) from airbyte_cdk.sources.types import Record, StreamSlice + datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z" cursor_granularity = "PT0.000001S" FAKE_NOW = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) @@ -33,7 +36,7 @@ def now(cls, tz=None): return FAKE_NOW -@pytest.fixture() +@pytest.fixture def mock_datetime_now(monkeypatch): monkeypatch.setattr(datetime, "datetime", MockedNowDatetime) @@ -897,7 +900,7 @@ def test_request_option_with_empty_stream_slice(stream_slice): config=config, parameters={}, ) - assert {} == slicer.get_request_params(stream_slice=stream_slice) + assert slicer.get_request_params(stream_slice=stream_slice) == {} @pytest.mark.parametrize( diff --git a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py index e1cd6d19..655543e7 100644 --- a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from collections import OrderedDict from unittest.mock import Mock import pytest + from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( PerPartitionCursor, @@ -15,6 +17,7 @@ from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import Record + PARTITION = { "partition_key string": "partition value", "partition_key int": 1, @@ -72,8 +75,7 @@ def test_partition_with_different_key_orders(): def test_given_tuples_in_json_then_deserialization_convert_to_list(): - """ - This is a known issue with the current implementation. However, the assumption is that this wouldn't be a problem as we only use the + """This is a known issue with the current implementation. 
However, the assumption is that this wouldn't be a problem as we only use the immutability and we expect stream slices to be immutable anyway """ serializer = PerPartitionKeySerializer() @@ -119,12 +121,12 @@ def build(self): return cursor -@pytest.fixture() +@pytest.fixture def mocked_partition_router(): return Mock(spec=PartitionRouter) -@pytest.fixture() +@pytest.fixture def mocked_cursor_factory(): cursor_factory = Mock() cursor_factory.create.return_value = MockedCursorBuilder().build() diff --git a/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py b/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py index 9d0216ff..d1302ca4 100644 --- a/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py +++ b/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from unittest.mock import MagicMock, patch +from orjson import orjson + from airbyte_cdk.models import ( AirbyteStateBlob, AirbyteStateMessage, @@ -24,7 +27,7 @@ from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.types import Record -from orjson import orjson + CURSOR_FIELD = "cursor_field" SYNC_MODE = SyncMode.incremental @@ -325,8 +328,7 @@ def test_substream_without_input_state(): def test_partition_limitation(caplog): - """ - Test that when the number of partitions exceeds the maximum allowed limit in PerPartitionCursor, + """Test that when the number of partitions exceeds the maximum allowed limit in PerPartitionCursor, the oldest partitions are dropped, and the state is updated accordingly. In this test, we set the maximum number of partitions to 2 and provide 3 partitions. @@ -454,8 +456,7 @@ def test_partition_limitation(caplog): def test_perpartition_with_fallback(caplog): - """ - Test that when the number of partitions exceeds the limit in PerPartitionCursor, + """Test that when the number of partitions exceeds the limit in PerPartitionCursor, the cursor falls back to using the global cursor for state management. This test also checks that the appropriate warning logs are emitted when the partition limit is exceeded. @@ -604,8 +605,7 @@ def test_perpartition_with_fallback(caplog): def test_per_partition_cursor_within_limit(caplog): - """ - Test that the PerPartitionCursor correctly updates the state for each partition + """Test that the PerPartitionCursor correctly updates the state for each partition when the number of partitions is within the allowed limit. This test also checks that no warning logs are emitted when the partition limit is not exceeded. diff --git a/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py b/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py index 90321449..d59f7c10 100644 --- a/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.incremental import ( ChildPartitionResumableFullRefreshCursor, ResumableFullRefreshCursor, diff --git a/unit_tests/sources/declarative/interpolation/test_filters.py b/unit_tests/sources/declarative/interpolation/test_filters.py index 82dd2bf1..c5e60bf4 100644 --- a/unit_tests/sources/declarative/interpolation/test_filters.py +++ b/unit_tests/sources/declarative/interpolation/test_filters.py @@ -1,12 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import base64 import hashlib import pytest + from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation + interpolation = JinjaInterpolation() diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py b/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py index 015d45aa..7b6eb2fb 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean + config = { "parent": {"key_with_true": True}, "string_key": "compare_me", diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py b/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py index 87843915..0ea09466 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping @@ -42,6 +44,6 @@ def test(test_name, key, expected_value): kwargs = {"a": "VALUE_FROM_KWARGS"} mapping = InterpolatedMapping(mapping=d, parameters={"b": "VALUE_FROM_PARAMETERS", "k": "key"}) - interpolated = mapping.eval(config, **{"kwargs": kwargs}) + interpolated = mapping.eval(config, kwargs=kwargs) assert interpolated[key] == expected_value diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py b/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py index 809f368c..85e967a5 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import dpath import pytest + from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import ( InterpolatedNestedMapping, ) @@ -48,6 +50,6 @@ def test(test_name, path, expected_value): mapping=d, parameters={"b": "VALUE_FROM_PARAMETERS", "k": "key"} ) - interpolated = mapping.eval(config, **{"kwargs": kwargs}) + interpolated = mapping.eval(config, kwargs=kwargs) assert dpath.get(interpolated, path) == expected_value diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_string.py b/unit_tests/sources/declarative/interpolation/test_interpolated_string.py index f0f1a995..b1f79205 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_string.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_string.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString + config = {"field": "value"} parameters = {"hello": "world"} kwargs = {"c": "airbyte"} @@ -22,4 +25,4 @@ ) def test_interpolated_string(test_name, input_string, expected_value): s = InterpolatedString.create(input_string, parameters=parameters) - assert s.eval(config, **{"kwargs": kwargs}) == expected_value + assert s.eval(config, kwargs=kwargs) == expected_value diff --git a/unit_tests/sources/declarative/interpolation/test_jinja.py b/unit_tests/sources/declarative/interpolation/test_jinja.py index 7534e929..d77f49a3 100644 --- a/unit_tests/sources/declarative/interpolation/test_jinja.py +++ b/unit_tests/sources/declarative/interpolation/test_jinja.py @@ -1,15 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import datetime import pytest -from airbyte_cdk import StreamSlice -from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from freezegun import freeze_time from jinja2.exceptions import TemplateSyntaxError +from airbyte_cdk import StreamSlice +from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation + + interpolation = JinjaInterpolation() @@ -45,7 +48,7 @@ def test_get_value_from_stream_slice(): s = "{{ stream_slice['date'] }}" config = {"date": "2022-01-01"} stream_slice = {"date": "2020-09-09"} - val = interpolation.eval(s, config, **{"stream_slice": stream_slice}) + val = interpolation.eval(s, config, stream_slice=stream_slice) assert val == "2020-09-09" @@ -53,7 +56,7 @@ def test_get_missing_value_from_stream_slice(): s = "{{ stream_slice['date'] }}" config = {"date": "2022-01-01"} stream_slice = {} - val = interpolation.eval(s, config, **{"stream_slice": stream_slice}) + val = interpolation.eval(s, config, stream_slice=stream_slice) assert val is None @@ -61,7 +64,7 @@ def test_get_value_from_a_list_of_mappings(): s = "{{ records[0]['date'] }}" config = {"date": "2022-01-01"} records = [{"date": "2020-09-09"}] - val = interpolation.eval(s, config, **{"records": records}) + val = interpolation.eval(s, config, records=records) assert val == "2020-09-09" @@ -249,10 +252,10 @@ def test_undeclared_variables(template_string, expected_error, expected_value): if expected_error: with pytest.raises(expected_error): - interpolation.eval(template_string, config=config, **{"to_be": "that_is_the_question"}) + interpolation.eval(template_string, config=config, to_be="that_is_the_question") else: actual_value = interpolation.eval( - template_string, config=config, **{"to_be": "that_is_the_question"} + template_string, config=config, to_be="that_is_the_question" ) assert actual_value == expected_value @@ -340,10 +343,10 @@ def test_macros_timezone(template_string: str, expected_value: str): def test_interpolation_private_partition_attribute(): inner_partition = StreamSlice(partition={}, cursor_slice={}) expected_output = "value" - setattr(inner_partition, "parent_stream_fields", expected_output) + inner_partition.parent_stream_fields = expected_output stream_slice = StreamSlice(partition=inner_partition, cursor_slice={}) template = "{{ stream_slice._partition.parent_stream_fields }}" - actual_output = JinjaInterpolation().eval(template, {}, **{"stream_slice": stream_slice}) + actual_output = JinjaInterpolation().eval(template, {}, stream_slice=stream_slice) assert actual_output == expected_output diff --git a/unit_tests/sources/declarative/interpolation/test_macros.py b/unit_tests/sources/declarative/interpolation/test_macros.py index cd16bd9f..3247a577 100644 --- a/unit_tests/sources/declarative/interpolation/test_macros.py +++ b/unit_tests/sources/declarative/interpolation/test_macros.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import pytest + from airbyte_cdk.sources.declarative.interpolation.macros import macros @@ -114,7 +116,5 @@ def test_timestamp(test_name, input_value, expected_output): def test_utc_datetime_to_local_timestamp_conversion(): - """ - This test ensures correct timezone handling independent of the timezone of the system on which the sync is running. 
- """ + """This test ensures correct timezone handling independent of the timezone of the system on which the sync is running.""" assert macros["format_datetime"](dt="2020-10-01T00:00:00Z", format="%s") == "1601510400" diff --git a/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py b/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py index 442e444a..fc0087fb 100644 --- a/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py +++ b/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock import pytest + from airbyte_cdk.sources.declarative.migrations.legacy_to_per_partition_state_migration import ( LegacyToPerPartitionStateMigration, ) @@ -13,15 +15,13 @@ CustomRetriever, DatetimeBasedCursor, DeclarativeStream, -) -from airbyte_cdk.sources.declarative.models import ( - LegacyToPerPartitionStateMigration as LegacyToPerPartitionStateMigrationModel, -) -from airbyte_cdk.sources.declarative.models import ( ParentStreamConfig, SimpleRetriever, SubstreamPartitionRouter, ) +from airbyte_cdk.sources.declarative.models import ( + LegacyToPerPartitionStateMigration as LegacyToPerPartitionStateMigrationModel, +) from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ( ManifestComponentTransformer, ) @@ -33,6 +33,7 @@ ) from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource + factory = ModelToComponentFactory() resolver = ManifestReferenceResolver() diff --git a/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py b/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py index 3cd06273..07540ab4 100644 --- a/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py +++ b/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ( ManifestComponentTransformer, ) diff --git a/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py b/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py index 36ae03cc..ce4a5a15 100644 --- a/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py +++ b/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.parsers.custom_exceptions import ( CircularReferenceException, UndefinedReferenceException, @@ -12,6 +14,7 @@ _parse_path, ) + resolver = ManifestReferenceResolver() diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index c8d0781a..64f6f648 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -3,12 +3,21 @@ # # mypy: ignore-errors +from __future__ import annotations + import datetime -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import freezegun import pendulum import pytest + +from unit_tests.sources.declarative.parsers.testing_components import ( + TestingCustomSubstreamPartitionRouter, + TestingSomeComponent, +) + from airbyte_cdk import AirbyteTracedException from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager @@ -132,10 +141,7 @@ from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import ( SingleUseRefreshTokenOauth2Authenticator, ) -from unit_tests.sources.declarative.parsers.testing_components import ( - TestingCustomSubstreamPartitionRouter, - TestingSomeComponent, -) + factory = ModelToComponentFactory() @@ -2148,7 +2154,7 @@ def test_no_transformations(self): ) assert isinstance(stream, DeclarativeStream) - assert [] == stream.retriever.record_selector.transformations + assert stream.retriever.record_selector.transformations == [] def test_remove_fields(self): content = f""" @@ -3201,8 +3207,7 @@ def test_create_concurrent_cursor_from_datetime_based_cursor( def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined(): - """ - Validates a special case for when the start_time.datetime_format and end_time.datetime_format are defined, the date to + """Validates a special case for when the start_time.datetime_format and end_time.datetime_format are defined, the date to string parser should not inherit from the parent DatetimeBasedCursor.datetime_format. The parent which uses an incorrect precision would fail if it were used by the dependent children. """ diff --git a/unit_tests/sources/declarative/parsers/testing_components.py b/unit_tests/sources/declarative/parsers/testing_components.py index 0d49e862..3401b04b 100644 --- a/unit_tests/sources/declarative/parsers/testing_components.py +++ b/unit_tests/sources/declarative/parsers/testing_components.py @@ -1,9 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Optional from airbyte_cdk.sources.declarative.extractors import DpathExtractor from airbyte_cdk.sources.declarative.partition_routers import SubstreamPartitionRouter @@ -17,25 +17,21 @@ @dataclass class TestingSomeComponent(DefaultErrorHandler): - """ - A basic test class with various field permutations used to test manifests with custom components - """ + """A basic test class with various field permutations used to test manifests with custom components""" subcomponent_field_with_hint: DpathExtractor = field( default_factory=lambda: DpathExtractor(field_path=[], config={}, parameters={}) ) basic_field: str = "" - optional_subcomponent_field: Optional[RequestOption] = None - list_of_subcomponents: List[RequestOption] = None + optional_subcomponent_field: RequestOption | None = None + list_of_subcomponents: list[RequestOption] = None without_hint = None paginator: DefaultPaginator = None @dataclass class TestingCustomSubstreamPartitionRouter(SubstreamPartitionRouter): - """ - A test class based on a SubstreamPartitionRouter used for testing manifests that use custom components. - """ + """A test class based on a SubstreamPartitionRouter used for testing manifests that use custom components.""" custom_field: str custom_pagination_strategy: PaginationStrategy diff --git a/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py index 697a0605..d7322991 100644 --- a/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest as pytest + from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -319,7 +321,7 @@ def test_request_option_before_updating_cursor(): ], parameters={}, ) - assert {} == slicer.get_request_params() - assert {} == slicer.get_request_headers() - assert {} == slicer.get_request_body_json() - assert {} == slicer.get_request_body_data() + assert slicer.get_request_params() == {} + assert slicer.get_request_headers() == {} + assert slicer.get_request_body_json() == {} + assert slicer.get_request_body_data() == {} diff --git a/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py index baa3ad8d..6ab5bdef 100644 --- a/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import pytest as pytest + from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ( ListPartitionRouter, ) @@ -12,6 +14,7 @@ ) from airbyte_cdk.sources.types import StreamSlice + partition_values = ["customer", "store", "subscription"] cursor_field = "owner_resource" parameters = {"cursor_field": "owner_resource"} @@ -152,7 +155,7 @@ def test_request_option_is_empty_if_no_stream_slice(stream_slice): request_option=request_option, parameters={}, ) - assert {} == partition_router.get_request_body_data(stream_slice=stream_slice) + assert partition_router.get_request_body_data(stream_slice=stream_slice) == {} @pytest.mark.parametrize( @@ -201,7 +204,7 @@ def test_request_option_before_updating_cursor(): ) stream_slice = {cursor_field: "customer"} - assert {} == partition_router.get_request_params(stream_slice) - assert {} == partition_router.get_request_headers() - assert {} == partition_router.get_request_body_json() - assert {} == partition_router.get_request_body_data() + assert partition_router.get_request_params(stream_slice) == {} + assert partition_router.get_request_headers() == {} + assert partition_router.get_request_body_json() == {} + assert partition_router.get_request_body_data() == {} diff --git a/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py b/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py index 81de362d..16c53118 100644 --- a/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py +++ b/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py @@ -1,11 +1,15 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import copy -from typing import Any, List, Mapping, MutableMapping, Optional, Union +from collections.abc import Mapping, MutableMapping +from typing import Any from unittest.mock import MagicMock import pytest import requests_mock +from orjson import orjson + from airbyte_cdk.models import ( AirbyteMessage, AirbyteStateBlob, @@ -20,7 +24,7 @@ SyncMode, ) from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource -from orjson import orjson + SUBSTREAM_MANIFEST: MutableMapping[str, Any] = { "version": "0.51.42", @@ -237,8 +241,8 @@ def _run_read( manifest: Mapping[str, Any], config: Mapping[str, Any], stream_name: str, - state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None, -) -> List[AirbyteMessage]: + state: list[AirbyteStateMessage] | MutableMapping[str, Any] | None = None, +) -> list[AirbyteMessage]: source = ManifestDeclarativeSource(source_config=manifest) catalog = ConfiguredAirbyteCatalog( streams=[ @@ -260,8 +264,7 @@ def _run_read( def run_incremental_parent_state_test( manifest, mock_requests, expected_records, initial_state, expected_states ): - """ - Run an incremental parent state test for the specified stream. + """Run an incremental parent state test for the specified stream. This function performs the following steps: 1. Mocks the API requests as defined in mock_requests. 
@@ -849,9 +852,7 @@ def test_incremental_parent_state( def test_incremental_parent_state_migration( test_name, manifest, mock_requests, expected_records, initial_state, expected_state ): - """ - Test incremental partition router with parent state migration - """ + """Test incremental partition router with parent state migration""" _stream_name = "post_comment_votes" config = { "start_date": "2024-01-01T00:00:01Z", @@ -1047,9 +1048,7 @@ def test_incremental_parent_state_migration( def test_incremental_parent_state_no_slices( test_name, manifest, mock_requests, expected_records, initial_state, expected_state ): - """ - Test incremental partition router with no parent records - """ + """Test incremental partition router with no parent records""" _stream_name = "post_comment_votes" config = { "start_date": "2024-01-01T00:00:01Z", @@ -1254,9 +1253,7 @@ def test_incremental_parent_state_no_slices( def test_incremental_parent_state_no_records( test_name, manifest, mock_requests, expected_records, initial_state, expected_state ): - """ - Test incremental partition router with no child records - """ + """Test incremental partition router with no child records""" _stream_name = "post_comment_votes" config = { "start_date": "2024-01-01T00:00:01Z", @@ -1489,8 +1486,7 @@ def test_incremental_parent_state_no_records( def test_incremental_parent_state_no_incremental_dependency( test_name, manifest, mock_requests, expected_records, initial_state, expected_state ): - """ - This is a pretty complicated test that syncs a low-code connector stream with three levels of substreams + """This is a pretty complicated test that syncs a low-code connector stream with three levels of substreams - posts: (ids: 1, 2, 3) - post comments: (parent post 1 with ids: 9, 10, 11, 12; parent post 2 with ids: 20, 21; parent post 3 with id: 30) - post comment votes: (parent comment 10 with ids: 100, 101; parent comment 11 with id: 102; @@ -1501,7 +1497,6 @@ def test_incremental_parent_state_no_incremental_dependency( parent stream requests use the incoming config as query parameters and the substream state messages does not contain parent stream state. """ - _stream_name = "post_comment_votes" config = { "start_date": "2024-01-01T00:00:01Z", diff --git a/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py index 82cc7ba3..ae9cbd0a 100644 --- a/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import ( SinglePartitionRouter, diff --git a/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py index f42bd554..fe5d062f 100644 --- a/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging +from collections.abc import Iterable, Mapping, MutableMapping from functools import partial -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any import pytest as pytest + from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, SyncMode, Type from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.declarative.incremental import ( @@ -35,6 +38,7 @@ from airbyte_cdk.sources.types import Record from airbyte_cdk.utils import AirbyteTracedException + parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}] more_records = [ {"id": 10, "data": "data10", "slice": "second_parent"}, @@ -82,7 +86,7 @@ def name(self) -> str: return self._name @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return "id" @property @@ -97,16 +101,16 @@ def state(self, value: Mapping[str, Any]) -> None: def is_resumable(self) -> bool: return bool(self._cursor) - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: return self._cursor def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_state: Mapping[str, Any] = None, - ) -> Iterable[Optional[StreamSlice]]: + ) -> Iterable[StreamSlice | None]: for s in self._slices: if isinstance(s, StreamSlice): yield s @@ -116,7 +120,7 @@ def stream_slices( def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -156,7 +160,7 @@ def __init__(self, slices, records, name, cursor_field="", cursor=None, date_ran def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -178,7 +182,7 @@ def __init__( name, cursor_field="", cursor=None, - record_pages: Optional[List[List[Mapping[str, Any]]]] = None, + record_pages: list[list[Mapping[str, Any]]] | None = None, ): super().__init__(slices, [], name, cursor_field, cursor) if record_pages: @@ -190,7 +194,7 @@ def __init__( def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -750,8 +754,7 @@ def test_substream_using_incremental_parent_stream(): def test_substream_checkpoints_after_each_parent_partition(): - """ - This test validates the specific behavior that when getting all parent records for a substream, + """This test validates the specific behavior that when getting all parent records for a substream, we are still updating state so that the parent stream's state is updated after we finish getting all parent records for the parent slice (not just the substream) """ diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py index eb2ecc1d..a45148d2 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py +++ 
b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py @@ -1,14 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock import pytest + from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.constant_backoff_strategy import ( ConstantBackoffStrategy, ) + BACKOFF_TIME = 10 PARAMETERS_BACKOFF_TIME = 20 CONFIG_BACKOFF_TIME = 30 diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py index 3e5b4c90..4985aacd 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py @@ -1,14 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock import pytest + from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.exponential_backoff_strategy import ( ExponentialBackoffStrategy, ) + parameters = {"backoff": 5} config = {"backoff": 5} diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py index 99af3ca2..c5adfaf5 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import re from unittest.mock import MagicMock import pytest + from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import ( get_numeric_value_from_header, ) @@ -23,23 +25,23 @@ "test_get_numeric_value_with_regex", {"header": "61,60"}, "header", - re.compile("([-+]?\d+)"), + re.compile(r"([-+]?\d+)"), 61, - ), # noqa + ), ( "test_get_numeric_value_with_regex_no_header", {"header": "61,60"}, "notheader", - re.compile("([-+]?\d+)"), + re.compile(r"([-+]?\d+)"), None, - ), # noqa + ), ( "test_get_numeric_value_with_regex_not_matching", {"header": "abc61,60"}, "header", - re.compile("([-+]?\d+)"), + re.compile(r"([-+]?\d+)"), None, - ), # noqa + ), ], ) def test_get_numeric_value_from_header(test_name, headers, requested_header, regex, expected_value): diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py index 6db2f9fd..e61574f8 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py @@ -1,16 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from unittest.mock import MagicMock import pytest +from requests import Response + from airbyte_cdk import AirbyteTracedException from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.wait_time_from_header_backoff_strategy import ( WaitTimeFromHeaderBackoffStrategy, ) -from requests import Response + SOME_BACKOFF_TIME = 60 _A_RETRY_HEADER = "retry-header" diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py index 20dba620..1ec5fd88 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py @@ -1,16 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest import requests + from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.wait_until_time_from_header_backoff_strategy import ( WaitUntilTimeFromHeaderBackoffStrategy, ) + SOME_BACKOFF_TIME = 60 REGEX = "[-+]?\\d+" @@ -52,9 +54,9 @@ "wait_until", "1600000060,60", None, - "[-+]?\d+", + r"[-+]?\d+", 60, - ), # noqa + ), ( "test_wait_until_time_from_header_with_regex_from_parameters", "wait_until", @@ -63,7 +65,6 @@ "{{parameters['regex']}}", 60, ), - # noqa ( "test_wait_until_time_from_header_with_regex_from_config", "wait_until", @@ -71,15 +72,15 @@ None, "{{config['regex']}}", 60, - ), # noqa + ), ( "test_wait_until_time_from_header_with_regex_no_match", "wait_time", "...", None, - "[-+]?\d+", + r"[-+]?\d+", None, - ), # noqa + ), ( "test_wait_until_no_header_with_min", "absent_header", diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py b/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py index 3d2f551c..3b5915e8 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock import pytest import requests + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.requesters.error_handlers import HttpResponseFilter from airbyte_cdk.sources.declarative.requesters.error_handlers.composite_error_handler import ( @@ -19,6 +21,7 @@ ResponseAction, ) + SOME_BACKOFF_TIME = 60 diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py b/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py index f97fa05f..8e5c36f4 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from unittest.mock import MagicMock import pytest import requests + from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.constant_backoff_strategy import ( ConstantBackoffStrategy, ) @@ -25,6 +27,7 @@ ResponseAction, ) + SOME_BACKOFF_TIME = 60 diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py b/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py index dc0c004a..a3fb363a 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock import pytest +from requests import RequestException, Response + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.requesters.error_handlers.default_http_response_filter import ( DefaultHttpResponseFilter, @@ -13,7 +16,6 @@ DEFAULT_ERROR_MAPPING, ) from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction -from requests import RequestException, Response @pytest.mark.parametrize( diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py b/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py index 1acc95f2..c9f09bb1 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import pytest import requests + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.requesters.error_handlers import HttpResponseFilter from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( diff --git a/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py b/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py index 4d2920ea..10769922 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import pytest import requests + from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import ( diff --git a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py index d02562b0..cc21fc16 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json from unittest.mock import MagicMock import pytest import requests + from airbyte_cdk.sources.declarative.decoders import JsonDecoder, XmlDecoder from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.declarative.requesters.paginators.default_paginator import ( diff --git a/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py index 12b81010..0ac45ee8 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import requests + from airbyte_cdk.sources.declarative.requesters.paginators.no_pagination import NoPagination diff --git a/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py b/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py index 8c357349..7b0e9eea 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json -from typing import Any, Optional +from typing import Any import pytest import requests + from airbyte_cdk.sources.declarative.requesters.paginators.strategies.offset_increment import ( OffsetIncrement, ) @@ -55,7 +57,7 @@ def test_offset_increment_paginator_strategy( assert expected_offset == paginator_strategy._offset paginator_strategy.reset() - assert 0 == paginator_strategy._offset + assert paginator_strategy._offset == 0 def test_offset_increment_paginator_strategy_rises(): @@ -77,7 +79,7 @@ def test_offset_increment_paginator_strategy_rises(): ], ) def test_offset_increment_paginator_strategy_initial_token( - inject_on_first_request: bool, expected_initial_token: Optional[Any] + inject_on_first_request: bool, expected_initial_token: Any | None ): paginator_strategy = OffsetIncrement( page_size=20, parameters={}, config={}, inject_on_first_request=inject_on_first_request diff --git a/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py b/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py index 9ec994e2..b869680f 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json -from typing import Any, Optional +from typing import Any import pytest import requests + from airbyte_cdk.sources.declarative.requesters.paginators.strategies.page_increment import ( PageIncrement, ) @@ -70,7 +72,7 @@ def test_page_increment_paginator_strategy_malformed_page_size(page_size): ], ) def test_page_increment_paginator_strategy_initial_token( - inject_on_first_request: bool, start_from_page: int, expected_initial_token: Optional[Any] + inject_on_first_request: bool, start_from_page: int, expected_initial_token: Any | None ): paginator_strategy = PageIncrement( page_size=20, diff --git a/unit_tests/sources/declarative/requesters/paginators/test_request_option.py b/unit_tests/sources/declarative/requesters/paginators/test_request_option.py index cef4fe87..c0f88c92 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_request_option.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_request_option.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.requesters.request_option import ( RequestOption, RequestOptionType, diff --git a/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py b/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py index 201636f1..4df26915 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import Mock, call +from pytest import fixture + from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( PaginationStrategy, @@ -14,7 +17,7 @@ StopConditionPaginationStrategyDecorator, ) from airbyte_cdk.sources.types import Record -from pytest import fixture + ANY_RECORD = Mock() NO_RECORD = None diff --git a/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py b/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py index 7cbfa78d..8c01f937 100644 --- a/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py +++ b/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py @@ -1,8 +1,10 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.requesters.request_option import ( RequestOption, RequestOptionType, diff --git a/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py b/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py index 3e11bfa5..ce98bab8 100644 --- a/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py +++ b/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_options_provider import ( InterpolatedRequestOptionsProvider, ) + state = {"date": "2021-01-01"} stream_slice = {"start_date": "2020-01-01"} next_page_token = {"offset": 12345, "page": 27} diff --git a/unit_tests/sources/declarative/requesters/test_http_job_repository.py b/unit_tests/sources/declarative/requesters/test_http_job_repository.py index aa2a13f7..fc68b639 100644 --- a/unit_tests/sources/declarative/requesters/test_http_job_repository.py +++ b/unit_tests/sources/declarative/requesters/test_http_job_repository.py @@ -1,11 +1,12 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. - +from __future__ import annotations import json from unittest import TestCase from unittest.mock import Mock import pytest + from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.declarative.decoders import NoopDecoder from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder @@ -31,6 +32,7 @@ from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse + _ANY_CONFIG = {} _ANY_SLICE = StreamSlice(partition={}, cursor_slice={}) _URL_BASE = "https://api.sendgrid.com/v3/" diff --git a/unit_tests/sources/declarative/requesters/test_http_requester.py b/unit_tests/sources/declarative/requesters/test_http_requester.py index 1428319f..429c18c5 100644 --- a/unit_tests/sources/declarative/requesters/test_http_requester.py +++ b/unit_tests/sources/declarative/requesters/test_http_requester.py @@ -1,14 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from unittest import mock from unittest.mock import MagicMock from urllib.parse import parse_qs, urlparse import pytest as pytest import requests +from requests import PreparedRequest + from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies import ( @@ -29,7 +33,6 @@ UserDefinedBackoffException, ) from airbyte_cdk.sources.types import Config -from requests import PreparedRequest @pytest.fixture @@ -39,13 +42,13 @@ def factory( url_base: str = "https://test_base_url.com", path: str = "/", http_method: str = HttpMethod.GET, - request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None, - authenticator: Optional[DeclarativeAuthenticator] = None, - error_handler: Optional[ErrorHandler] = None, - config: Optional[Config] = None, + request_options_provider: InterpolatedRequestOptionsProvider | None = None, + authenticator: DeclarativeAuthenticator | None = None, + error_handler: ErrorHandler | None = None, + config: Config | None = None, parameters: Mapping[str, Any] = None, disable_retries: bool = False, - message_repository: Optional[MessageRepository] = None, + message_repository: MessageRepository | None = None, use_cache: bool = False, ) -> HttpRequester: return HttpRequester( @@ -179,12 +182,12 @@ def test_path(test_name, path, expected_path): def create_requester( - url_base: Optional[str] = None, - parameters: Optional[Mapping[str, Any]] = {}, - config: 
Optional[Config] = None, - path: Optional[str] = None, - authenticator: Optional[DeclarativeAuthenticator] = None, - error_handler: Optional[ErrorHandler] = None, + url_base: str | None = None, + parameters: Mapping[str, Any] | None = {}, + config: Config | None = None, + path: str | None = None, + authenticator: DeclarativeAuthenticator | None = None, + error_handler: ErrorHandler | None = None, ) -> HttpRequester: requester = HttpRequester( name="name", diff --git a/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py b/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py index 8882e918..425abde5 100644 --- a/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py +++ b/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest as pytest + from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_input_provider import ( InterpolatedRequestInputProvider, diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index d2eb2d15..260f4f1b 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json from unittest.mock import MagicMock, Mock, patch import pytest import requests + from airbyte_cdk import YamlDeclarativeSource from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type from airbyte_cdk.sources.declarative.auth.declarative_authenticator import NoAuth @@ -30,6 +32,7 @@ ) from airbyte_cdk.sources.types import Record, StreamSlice + A_SLICE_STATE = {"slice_state": "slice state value"} A_STREAM_SLICE = StreamSlice(cursor_slice={"stream slice": "slice value"}, partition={}) A_STREAM_STATE = {"stream state": "state value"} diff --git a/unit_tests/sources/declarative/schema/source_test/SourceTest.py b/unit_tests/sources/declarative/schema/source_test/SourceTest.py index 8d6a26cb..7191ad0e 100644 --- a/unit_tests/sources/declarative/schema/source_test/SourceTest.py +++ b/unit_tests/sources/declarative/schema/source_test/SourceTest.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations class SourceTest: diff --git a/unit_tests/sources/declarative/schema/test_default_schema_loader.py b/unit_tests/sources/declarative/schema/test_default_schema_loader.py index 38e617f9..c060be73 100644 --- a/unit_tests/sources/declarative/schema/test_default_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_default_schema_loader.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from unittest.mock import MagicMock import pytest + from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader diff --git a/unit_tests/sources/declarative/schema/test_inline_schema_loader.py b/unit_tests/sources/declarative/schema/test_inline_schema_loader.py index ad44ee33..44d68f85 100644 --- a/unit_tests/sources/declarative/schema/test_inline_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_inline_schema_loader.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader diff --git a/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py b/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py index a53a88a9..2e68bf03 100644 --- a/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py @@ -1,9 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + from unittest.mock import patch import pytest + from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import ( JsonFileSchemaLoader, _default_file_path, @@ -51,7 +54,7 @@ def test_exclude_cdk_packages(mocked_sys): "airbyte_cdk.sources.concurrent_source.concurrent_source_adapter", "source_gitlab.utils", ] - mocked_sys.modules = {key: "" for key in keys} + mocked_sys.modules = dict.fromkeys(keys, "") default_file_path = _default_file_path() diff --git a/unit_tests/sources/declarative/spec/test_spec.py b/unit_tests/sources/declarative/spec/test_spec.py index 8b924cb4..952688d6 100644 --- a/unit_tests/sources/declarative/spec/test_spec.py +++ b/unit_tests/sources/declarative/spec/test_spec.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.models import AdvancedAuth, AuthFlowType, ConnectorSpecification from airbyte_cdk.sources.declarative.models.declarative_component_schema import AuthFlow from airbyte_cdk.sources.declarative.spec.spec import Spec diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py index bc97b2c5..8f1005e2 100644 --- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py +++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py @@ -1,15 +1,19 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import copy import json +from collections.abc import Iterable, Mapping from datetime import datetime, timedelta, timezone -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any import freezegun import isodate import pendulum +from deprecated.classic import deprecated + from airbyte_cdk.models import ( AirbyteMessage, AirbyteRecordMessage, @@ -38,7 +42,7 @@ from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse from airbyte_cdk.utils import AirbyteTracedException -from deprecated.classic import deprecated + _CONFIG = {"start_date": "2024-07-01T00:00:00.000Z"} @@ -367,8 +371,7 @@ @deprecated("See note in docstring for more information") class DeclarativeStreamDecorator(Stream): - """ - Helper class that wraps an existing DeclarativeStream but allows for overriding the output of read_records() to + """Helper class that wraps an existing DeclarativeStream but allows for overriding the output of read_records() to make it easier to mock behavior and test how low-code streams integrate with the Concurrent CDK. NOTE: We are not using that for now but the intent was to scope the tests to only testing that streams were properly instantiated and @@ -382,7 +385,7 @@ class DeclarativeStreamDecorator(Stream): def __init__( self, declarative_stream: DeclarativeStream, - slice_to_records_mapping: Mapping[tuple[str, str], List[Mapping[str, Any]]], + slice_to_records_mapping: Mapping[tuple[str, str], list[Mapping[str, Any]]], ): self._declarative_stream = declarative_stream self._slice_to_records_mapping = slice_to_records_mapping @@ -392,15 +395,15 @@ def name(self) -> str: return self._declarative_stream.name @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return self._declarative_stream.primary_key def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: if isinstance(stream_slice, StreamSlice): slice_key = (stream_slice.get("start_time"), stream_slice.get("end_time")) @@ -424,15 +427,12 @@ def read_records( def get_json_schema(self) -> Mapping[str, Any]: return self._declarative_stream.get_json_schema() - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: return self._declarative_stream.get_cursor() def test_group_streams(): - """ - Tests the grouping of low-code streams into ones that can be processed concurrently vs ones that must be processed concurrently - """ - + """Tests the grouping of low-code streams into ones that can be processed concurrently vs ones that must be processed concurrently""" catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( @@ -497,11 +497,9 @@ def test_group_streams(): @freezegun.freeze_time(time_to_freeze=datetime(2024, 9, 1, 0, 0, 0, 0, tzinfo=timezone.utc)) def test_create_concurrent_cursor(): - """ - Validate that the ConcurrentDeclarativeSource properly instantiates a ConcurrentCursor from the + """Validate that the ConcurrentDeclarativeSource properly instantiates a ConcurrentCursor from the low-code DatetimeBasedCursor component """ - 
incoming_locations_state = { "slices": [ {"start": "2024-07-01T00:00:00", "end": "2024-07-31T00:00:00"}, @@ -566,9 +564,7 @@ def test_create_concurrent_cursor(): def test_check(): - """ - Verifies that the ConcurrentDeclarativeSource check command is run against synchronous streams - """ + """Verifies that the ConcurrentDeclarativeSource check command is run against synchronous streams""" with HttpMocker() as http_mocker: http_mocker.get( HttpRequest( @@ -605,9 +601,7 @@ def test_check(): def test_discover(): - """ - Verifies that the ConcurrentDeclarativeSource discover command returns concurrent and synchronous catalog definitions - """ + """Verifies that the ConcurrentDeclarativeSource discover command returns concurrent and synchronous catalog definitions""" expected_stream_names = ["party_members", "palaces", "locations", "party_members_skills"] source = ConcurrentDeclarativeSource( @@ -626,8 +620,8 @@ def test_discover(): def _mock_requests( http_mocker: HttpMocker, url: str, - query_params: List[Dict[str, str]], - responses: List[HttpResponse], + query_params: list[dict[str, str]], + responses: list[HttpResponse], ) -> None: assert len(query_params) == len(responses), "Expecting as many slices as response" @@ -636,7 +630,7 @@ def _mock_requests( def _mock_party_members_requests( - http_mocker: HttpMocker, slices_and_responses: List[Tuple[Dict[str, str], HttpResponse]] + http_mocker: HttpMocker, slices_and_responses: list[tuple[dict[str, str], HttpResponse]] ) -> None: slices = list(map(lambda slice_and_response: slice_and_response[0], slices_and_responses)) responses = list(map(lambda slice_and_response: slice_and_response[1], slices_and_responses)) @@ -649,7 +643,7 @@ def _mock_party_members_requests( ) -def _mock_locations_requests(http_mocker: HttpMocker, slices: List[Dict[str, str]]) -> None: +def _mock_locations_requests(http_mocker: HttpMocker, slices: list[dict[str, str]]) -> None: locations_query_params = list( map(lambda _slice: _slice | {"m": "active", "i": "1", "g": "country"}, slices) ) @@ -662,9 +656,7 @@ def _mock_locations_requests(http_mocker: HttpMocker, slices: List[Dict[str, str def _mock_party_members_skills_requests(http_mocker: HttpMocker) -> None: - """ - This method assumes _mock_party_members_requests has been called before else the stream won't work. 
- """ + """This method assumes _mock_party_members_requests has been called before else the stream won't work.""" http_mocker.get( HttpRequest("https://persona.metaverse.com/party_members/amamiya/skills"), _PARTY_MEMBERS_SKILLS_RESPONSE, @@ -681,9 +673,7 @@ def _mock_party_members_skills_requests(http_mocker: HttpMocker) -> None: @freezegun.freeze_time(_NOW) def test_read_with_concurrent_and_synchronous_streams(): - """ - Verifies that a ConcurrentDeclarativeSource processes concurrent streams followed by synchronous streams - """ + """Verifies that a ConcurrentDeclarativeSource processes concurrent streams followed by synchronous streams""" location_slices = [ {"start": "2024-07-01", "end": "2024-07-31"}, {"start": "2024-08-01", "end": "2024-08-31"}, @@ -805,8 +795,7 @@ def test_read_with_concurrent_and_synchronous_streams(): @freezegun.freeze_time(_NOW) def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): - """ - Verifies that a ConcurrentDeclarativeSource processes concurrent streams correctly using the incoming + """Verifies that a ConcurrentDeclarativeSource processes concurrent streams correctly using the incoming concurrent state format """ state = [ @@ -926,8 +915,7 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): @freezegun.freeze_time(_NOW) def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): - """ - Verifies that a ConcurrentDeclarativeSource processes concurrent streams correctly using the incoming + """Verifies that a ConcurrentDeclarativeSource processes concurrent streams correctly using the incoming legacy state format """ state = [ @@ -1050,9 +1038,7 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): @freezegun.freeze_time(_NOW) def test_read_concurrent_with_failing_partition_in_the_middle(): - """ - Verify that partial state is emitted when only some partitions are successful during a concurrent sync attempt - """ + """Verify that partial state is emitted when only some partitions are successful during a concurrent sync attempt""" location_slices = [ {"start": "2024-07-01", "end": "2024-07-31"}, # missing slice `{"start": "2024-08-01", "end": "2024-08-31"}` here @@ -1109,9 +1095,7 @@ def test_read_concurrent_with_failing_partition_in_the_middle(): @freezegun.freeze_time(_NOW) def test_read_concurrent_skip_streams_not_in_catalog(): - """ - Verifies that the ConcurrentDeclarativeSource only syncs streams that are specified in the incoming ConfiguredCatalog - """ + """Verifies that the ConcurrentDeclarativeSource only syncs streams that are specified in the incoming ConfiguredCatalog""" with HttpMocker() as http_mocker: catalog = ConfiguredAirbyteCatalog( streams=[ @@ -1425,7 +1409,7 @@ def create_wrapped_stream(stream: DeclarativeStream) -> Stream: ) -def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], List[StreamData]]: +def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], list[StreamData]]: match stream_name: case "locations": slices = [ @@ -1551,8 +1535,8 @@ def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], def get_records_for_stream( - stream_name: str, messages: List[AirbyteMessage] -) -> List[AirbyteRecordMessage]: + stream_name: str, messages: list[AirbyteMessage] +) -> list[AirbyteRecordMessage]: return [ message.record for message in messages @@ -1561,8 +1545,8 @@ def get_records_for_stream( def get_states_for_stream( - stream_name: str, messages: 
List[AirbyteMessage] -) -> List[AirbyteStateMessage]: + stream_name: str, messages: list[AirbyteMessage] +) -> list[AirbyteStateMessage]: return [ message.state for message in messages diff --git a/unit_tests/sources/declarative/test_declarative_stream.py b/unit_tests/sources/declarative/test_declarative_stream.py index f34fd137..0715d6fb 100644 --- a/unit_tests/sources/declarative/test_declarative_stream.py +++ b/unit_tests/sources/declarative/test_declarative_stream.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import MagicMock import pytest + from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, @@ -17,6 +19,7 @@ from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.types import StreamSlice + SLICE_NOT_CONSIDERED_FOR_EQUALITY = {} _name = "stream" @@ -81,9 +84,7 @@ def test_declarative_stream(): def test_declarative_stream_using_empty_slice(): - """ - Tests that a declarative_stream - """ + """Tests that a declarative_stream""" schema_loader = _schema_loader() records = [ diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index 19be3a82..b684f558 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -1,19 +1,23 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging import os import sys +from collections.abc import Mapping from copy import deepcopy from pathlib import Path -from typing import Any, List, Mapping +from typing import Any from unittest.mock import call, patch import pytest import requests import yaml +from jsonschema.exceptions import ValidationError + from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, @@ -28,7 +32,7 @@ from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever -from jsonschema.exceptions import ValidationError + logger = logging.getLogger("airbyte") @@ -41,8 +45,7 @@ class MockManifestDeclarativeSource(ManifestDeclarativeSource): - """ - Mock test class that is needed to monkey patch how we read from various files that make up a declarative source because of how our + """Mock test class that is needed to monkey patch how we read from various files that make up a declarative source because of how our tests write configuration files during testing. It is also used to properly namespace where files get written in specific cases like when we temporarily write files like spec.yaml to the package unit_tests, which is the directory where it will be read in during the tests. 
@@ -1689,7 +1692,7 @@ def test_only_parent_streams_use_cache(): assert not streams[2].retriever.requester.use_cache -def _run_read(manifest: Mapping[str, Any], stream_name: str) -> List[AirbyteMessage]: +def _run_read(manifest: Mapping[str, Any], stream_name: str) -> list[AirbyteMessage]: source = ManifestDeclarativeSource(source_config=manifest) catalog = ConfiguredAirbyteCatalog( streams=[ @@ -1707,10 +1710,10 @@ def _run_read(manifest: Mapping[str, Any], stream_name: str) -> List[AirbyteMess def test_declarative_component_schema_valid_ref_links(): def load_yaml(file_path) -> Mapping[str, Any]: - with open(file_path, "r") as file: + with open(file_path) as file: return yaml.safe_load(file) - def extract_refs(data, base_path="#") -> List[str]: + def extract_refs(data, base_path="#") -> list[str]: refs = [] if isinstance(data, dict): for key, value in data.items(): @@ -1735,7 +1738,7 @@ def resolve_pointer(data: Mapping[str, Any], pointer: str) -> bool: except (KeyError, TypeError): return False - def validate_refs(yaml_file: str) -> List[str]: + def validate_refs(yaml_file: str) -> list[str]: data = load_yaml(yaml_file) refs = extract_refs(data) invalid_refs = [ref for ref in refs if not resolve_pointer(data, ref.replace("#", ""))] diff --git a/unit_tests/sources/declarative/test_types.py b/unit_tests/sources/declarative/test_types.py index 1a15dcfa..cf6eb76c 100644 --- a/unit_tests/sources/declarative/test_types.py +++ b/unit_tests/sources/declarative/test_types.py @@ -1,6 +1,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations import pytest + from airbyte_cdk.sources.types import StreamSlice diff --git a/unit_tests/sources/declarative/test_yaml_declarative_source.py b/unit_tests/sources/declarative/test_yaml_declarative_source.py index fc35f5b3..767f133f 100644 --- a/unit_tests/sources/declarative/test_yaml_declarative_source.py +++ b/unit_tests/sources/declarative/test_yaml_declarative_source.py @@ -1,15 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import os import tempfile import pytest +from yaml.parser import ParserError + from airbyte_cdk.sources.declarative.parsers.custom_exceptions import UndefinedReferenceException from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource -from yaml.parser import ParserError + logger = logging.getLogger("airbyte") @@ -23,20 +26,18 @@ class MockYamlDeclarativeSource(YamlDeclarativeSource): - """ - Mock test class that is needed to monkey patch how we read from various files that make up a declarative source because of how our + """Mock test class that is needed to monkey patch how we read from various files that make up a declarative source because of how our tests write configuration files during testing. It is also used to properly namespace where files get written in specific cases like when we temporarily write files like spec.yaml to the package unit_tests, which is the directory where it will be read in during the tests. """ def _read_and_parse_yaml_file(self, path_to_yaml_file): - """ - We override the default behavior because we use tempfile to write the yaml manifest to a temporary directory which is + """We override the default behavior because we use tempfile to write the yaml manifest to a temporary directory which is not mounted during runtime which prevents pkgutil.get_data() from being able to find the yaml file needed to generate # the declarative source. 
For tests we use open() which supports using an absolute path. """ - with open(path_to_yaml_file, "r") as f: + with open(path_to_yaml_file) as f: config_content = f.read() parsed_config = YamlDeclarativeSource._parse(config_content) return parsed_config diff --git a/unit_tests/sources/declarative/transformations/test_add_fields.py b/unit_tests/sources/declarative/transformations/test_add_fields.py index b598e7bc..f036b64b 100644 --- a/unit_tests/sources/declarative/transformations/test_add_fields.py +++ b/unit_tests/sources/declarative/transformations/test_add_fields.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, List, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any import pytest + from airbyte_cdk.sources.declarative.transformations import AddFields from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition from airbyte_cdk.sources.types import FieldPointer @@ -161,8 +164,8 @@ ) def test_add_fields( input_record: Mapping[str, Any], - field: List[Tuple[FieldPointer, str]], - field_type: Optional[str], + field: list[tuple[FieldPointer, str]], + field_type: str | None, kwargs: Mapping[str, Any], expected: Mapping[str, Any], ): diff --git a/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py b/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py index cdf52615..e3ec025a 100644 --- a/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py +++ b/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py @@ -1,11 +1,13 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from airbyte_cdk.sources.declarative.transformations.keys_to_lower_transformation import ( KeysToLowerTransformation, ) + _ANY_VALUE = -1 diff --git a/unit_tests/sources/declarative/transformations/test_remove_fields.py b/unit_tests/sources/declarative/transformations/test_remove_fields.py index 4638b7ea..f261249e 100644 --- a/unit_tests/sources/declarative/transformations/test_remove_fields.py +++ b/unit_tests/sources/declarative/transformations/test_remove_fields.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, List, Mapping +from collections.abc import Mapping +from typing import Any import pytest + from airbyte_cdk.sources.declarative.transformations import RemoveFields from airbyte_cdk.sources.types import FieldPointer @@ -161,7 +164,7 @@ ) def test_remove_fields( input_record: Mapping[str, Any], - field_pointers: List[FieldPointer], + field_pointers: list[FieldPointer], condition: str, expected: Mapping[str, Any], ): diff --git a/unit_tests/sources/embedded/test_embedded_integration.py b/unit_tests/sources/embedded/test_embedded_integration.py index f8e11cff..86d33f4d 100644 --- a/unit_tests/sources/embedded/test_embedded_integration.py +++ b/unit_tests/sources/embedded/test_embedded_integration.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import unittest -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from unittest.mock import MagicMock from airbyte_cdk.models import ( @@ -26,7 +28,7 @@ class TestIntegration(BaseEmbeddedIntegration): - def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Mapping[str, Any]: + def _handle_record(self, record: AirbyteRecordMessage, id: str | None) -> Mapping[str, Any]: return {"data": record.data, "id": id} diff --git a/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py b/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py index b05bff03..7d254ec2 100644 --- a/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py +++ b/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import unittest from datetime import datetime @@ -17,6 +18,7 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream + _FILE_WITH_UNKNOWN_EXTENSION = RemoteFile( uri="a.unknown_extension", last_modified=datetime.now(), file_type="csv" ) @@ -45,8 +47,7 @@ def setUp(self) -> None: def test_given_file_extension_does_not_match_when_check_availability_and_parsability_then_stream_is_still_available( self, ) -> None: - """ - Before, we had a validation on the file extension but it turns out that in production, users sometimes have mismatch there. The + """Before, we had a validation on the file extension but it turns out that in production, users sometimes have mismatch there. The example we've seen was for JSONL parser but the file extension was just `.json`. Note that there we more than one record extracted from this stream so it's not just that the file is one JSON object """ @@ -60,9 +61,7 @@ def test_given_file_extension_does_not_match_when_check_availability_and_parsabi assert is_available def test_not_available_given_no_files(self) -> None: - """ - If no files are returned, then the stream is not available. - """ + """If no files are returned, then the stream is not available.""" self._stream.get_files.return_value = [] is_available, reason = self._strategy.check_availability_and_parsability( @@ -73,9 +72,7 @@ def test_not_available_given_no_files(self) -> None: assert "No files were identified in the stream" in reason def test_parse_records_is_not_called_with_parser_max_n_files_for_parsability_set(self) -> None: - """ - If the stream parser sets parser_max_n_files_for_parsability to 0, then we should not call parse_records on it - """ + """If the stream parser sets parser_max_n_files_for_parsability to 0, then we should not call parse_records on it""" self._parser.parser_max_n_files_for_parsability = 0 self._stream.get_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION] @@ -88,9 +85,7 @@ def test_parse_records_is_not_called_with_parser_max_n_files_for_parsability_set assert self._stream_reader.open_file.called def test_passing_config_check(self) -> None: - """ - Test if the DefaultFileBasedAvailabilityStrategy correctly handles the check_config method defined on the parser. 
- """ + """Test if the DefaultFileBasedAvailabilityStrategy correctly handles the check_config method defined on the parser.""" self._parser.check_config.return_value = (False, "Ran into error") is_available, error_message = self._strategy.check_availability_and_parsability( self._stream, Mock(), Mock() @@ -99,8 +94,7 @@ def test_passing_config_check(self) -> None: assert "Ran into error" in error_message def test_catching_and_raising_custom_file_based_exception(self) -> None: - """ - Test if the DefaultFileBasedAvailabilityStrategy correctly handles the CustomFileBasedException + """Test if the DefaultFileBasedAvailabilityStrategy correctly handles the CustomFileBasedException by raising a CheckAvailabilityError when the get_files method is called. """ # Mock the get_files method to raise CustomFileBasedException when called diff --git a/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py b/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py index 84de3ad6..ad639b87 100644 --- a/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py +++ b/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py @@ -1,17 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - -from typing import Type +from __future__ import annotations import pytest +from jsonschema import ValidationError, validate +from pydantic.v1 import BaseModel + from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( AvroFormat, CsvFormat, ParquetFormat, ) -from jsonschema import ValidationError, validate -from pydantic.v1 import BaseModel @pytest.mark.parametrize( @@ -30,7 +30,7 @@ ], ) def test_parquet_file_type_is_not_a_valid_csv_file_type( - file_format: BaseModel, file_type: str, expected_error: Type[Exception] + file_format: BaseModel, file_type: str, expected_error: type[Exception] ) -> None: format_config = {file_type: {"filetype": file_type, "decimal_as_float": True}} diff --git a/unit_tests/sources/file_based/config/test_csv_format.py b/unit_tests/sources/file_based/config/test_csv_format.py index ace9e034..7c8acd61 100644 --- a/unit_tests/sources/file_based/config/test_csv_format.py +++ b/unit_tests/sources/file_based/config/test_csv_format.py @@ -1,17 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import unittest import pytest +from pydantic.v1.error_wrappers import ValidationError + from airbyte_cdk.sources.file_based.config.csv_format import ( CsvFormat, CsvHeaderAutogenerated, CsvHeaderFromCsv, CsvHeaderUserProvided, ) -from pydantic.v1.error_wrappers import ValidationError class CsvHeaderDefinitionTest(unittest.TestCase): diff --git a/unit_tests/sources/file_based/config/test_file_based_stream_config.py b/unit_tests/sources/file_based/config/test_file_based_stream_config.py index addc7223..a351699c 100644 --- a/unit_tests/sources/file_based/config/test_file_based_stream_config.py +++ b/unit_tests/sources/file_based/config/test_file_based_stream_config.py @@ -1,15 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import Any, Mapping, Type +from collections.abc import Mapping +from typing import Any import pytest as pytest +from pydantic.v1.error_wrappers import ValidationError + from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( CsvFormat, FileBasedStreamConfig, ) -from pydantic.v1.error_wrappers import ValidationError @pytest.mark.parametrize( @@ -90,7 +93,7 @@ def test_csv_config( file_type: str, input_format: Mapping[str, Any], expected_format: Mapping[str, Any], - expected_error: Type[Exception], + expected_error: type[Exception], ) -> None: stream_config = { "name": "stream1", diff --git a/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py b/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py index 8cb97715..84842fb5 100644 --- a/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py +++ b/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import unittest from unittest.mock import Mock @@ -19,15 +20,11 @@ def setUp(self) -> None: self._parser.parser_max_n_files_for_schema_inference = None def test_hardcoded_schema_inference_file_limit_is_returned(self) -> None: - """ - If the parser is not providing a limit, then we should use the hardcoded limit - """ + """If the parser is not providing a limit, then we should use the hardcoded limit""" assert self._policy.get_max_n_files_for_schema_inference(self._parser) == 10 def test_parser_limit_is_respected(self) -> None: - """ - If the parser is providing a limit, then we should use that limit - """ + """If the parser is providing a limit, then we should use that limit""" self._parser.parser_max_n_files_for_schema_inference = 1 assert self._policy.get_max_n_files_for_schema_inference(self._parser) == 1 diff --git a/unit_tests/sources/file_based/file_types/test_avro_parser.py b/unit_tests/sources/file_based/file_types/test_avro_parser.py index 2c52f9f8..279534cd 100644 --- a/unit_tests/sources/file_based/file_types/test_avro_parser.py +++ b/unit_tests/sources/file_based/file_types/test_avro_parser.py @@ -1,14 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime +import math import uuid import pytest + from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat from airbyte_cdk.sources.file_based.file_types import AvroParser + _default_avro_format = AvroFormat() _double_as_string_avro_format = AvroFormat(double_as_string=True) _uuid_value = uuid.uuid4() @@ -338,7 +342,7 @@ def test_convert_primitive_avro_type_to_json( _default_avro_format, "string", "hello world", "hello world", id="test_string" ), pytest.param( - _default_avro_format, {"logicalType": "decimal"}, 3.1415, "3.1415", id="test_decimal" + _default_avro_format, {"logicalType": "decimal"}, math.pi, "3.1415", id="test_decimal" ), pytest.param( _default_avro_format, diff --git a/unit_tests/sources/file_based/file_types/test_csv_parser.py b/unit_tests/sources/file_based/file_types/test_csv_parser.py index 295c4da6..e22eb1ae 100644 --- a/unit_tests/sources/file_based/file_types/test_csv_parser.py +++ b/unit_tests/sources/file_based/file_types/test_csv_parser.py @@ -1,18 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import asyncio import csv import io import logging import unittest +from collections.abc import Generator from datetime import datetime -from typing import Any, Dict, Generator, List, Set +from typing import Any from unittest import TestCase, mock from unittest.mock import Mock import pytest + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.file_based.config.csv_format import ( DEFAULT_FALSE_VALUES, @@ -32,6 +35,7 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.utils.traced_exception import AirbyteTracedException + PROPERTY_TYPES = { "col1": "null", "col2": "boolean", @@ -188,10 +192,10 @@ ], ) def test_cast_to_python_type( - row: Dict[str, str], - true_values: Set[str], - false_values: Set[str], - expected_output: Dict[str, Any], + row: dict[str, str], + true_values: set[str], + false_values: set[str], + expected_output: dict[str, Any], ) -> None: csv_format = CsvFormat(true_values=true_values, false_values=false_values) assert CsvParser._cast_types(row, PROPERTY_TYPES, csv_format, logger) == expected_output @@ -320,7 +324,7 @@ def test_given_empty_csv_file_when_infer_schema_then_raise_config_error(self) -> self._infer_schema() assert exception.value.failure_type == FailureType.config_error - def _test_infer_schema(self, rows: List[str], expected_type: str) -> None: + def _test_infer_schema(self, rows: list[str], expected_type: str) -> None: self._csv_reader.read_data.return_value = ({self._HEADER_NAME: row} for row in rows) inferred_schema = self._infer_schema() assert inferred_schema == {self._HEADER_NAME: {"type": expected_type}} @@ -336,14 +340,14 @@ def _infer_schema(self): class CsvFileBuilder: def __init__(self) -> None: - self._prefixed_rows: List[str] = [] - self._data: List[str] = [] + self._prefixed_rows: list[str] = [] + self._data: list[str] = [] - def with_prefixed_rows(self, rows: List[str]) -> "CsvFileBuilder": + def with_prefixed_rows(self, rows: list[str]) -> CsvFileBuilder: self._prefixed_rows = rows return self - def with_data(self, data: List[str]) -> "CsvFileBuilder": + def with_data(self, data: list[str]) -> CsvFileBuilder: self._data = data return self @@ -657,7 +661,7 @@ def test_read_data_with_encoding_error(self) -> None: assert "encoding" in ate.value.message assert self._csv_reader._get_headers.called - def _read_data(self) -> Generator[Dict[str, str], None, None]: + def _read_data(self) -> Generator[dict[str, str], None, None]: data_generator = self._csv_reader.read_data( self._config, self._file, diff --git a/unit_tests/sources/file_based/file_types/test_excel_parser.py b/unit_tests/sources/file_based/file_types/test_excel_parser.py index aac74be9..56989208 100644 --- a/unit_tests/sources/file_based/file_types/test_excel_parser.py +++ b/unit_tests/sources/file_based/file_types/test_excel_parser.py @@ -1,7 +1,7 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# - +from __future__ import annotations import datetime from io import BytesIO @@ -9,6 +9,7 @@ import pandas as pd import pytest + from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( ExcelFormat, FileBasedStreamConfig, diff --git a/unit_tests/sources/file_based/file_types/test_jsonl_parser.py b/unit_tests/sources/file_based/file_types/test_jsonl_parser.py index cf924131..af5e57fb 100644 --- a/unit_tests/sources/file_based/file_types/test_jsonl_parser.py +++ b/unit_tests/sources/file_based/file_types/test_jsonl_parser.py @@ -1,18 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import asyncio import io import json -from typing import Any, Dict +from typing import Any from unittest.mock import MagicMock, Mock import pytest + from airbyte_cdk.sources.file_based.exceptions import RecordParseError from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.file_types import JsonlParser + JSONL_CONTENT_WITHOUT_MULTILINE_JSON_OBJECTS = [ b'{"a": 1, "b": "1"}', b'{"a": 2, "b": "2"}', @@ -43,7 +46,7 @@ def stream_reader() -> MagicMock: return MagicMock(spec=AbstractFileBasedStreamReader) -def _infer_schema(stream_reader: MagicMock) -> Dict[str, Any]: +def _infer_schema(stream_reader: MagicMock) -> dict[str, Any]: loop = asyncio.new_event_loop() task = loop.create_task(JsonlParser().infer_schema(Mock(), Mock(), stream_reader, Mock())) loop.run_until_complete(task) @@ -86,13 +89,13 @@ def test_given_str_io_when_infer_then_return_proper_types(stream_reader: MagicMo def test_given_empty_record_when_infer_then_return_empty_schema(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO("{}".encode("utf-8")) + stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO(b"{}") schema = _infer_schema(stream_reader) assert schema == {} def test_given_no_records_when_infer_then_return_empty_schema(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO("".encode("utf-8")) + stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO(b"") schema = _infer_schema(stream_reader) assert schema == {} @@ -139,7 +142,7 @@ def test_given_multiline_json_objects_and_hits_read_limit_when_infer_then_return def test_given_multiple_records_then_merge_types(stream_reader: MagicMock) -> None: stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO( - '{"col1": 1}\n{"col1": 2.3}'.encode("utf-8") + b'{"col1": 1}\n{"col1": 2.3}' ) schema = _infer_schema(stream_reader) assert schema == {"col1": {"type": "number"}} diff --git a/unit_tests/sources/file_based/file_types/test_parquet_parser.py b/unit_tests/sources/file_based/file_types/test_parquet_parser.py index e0c06e86..36ca1084 100644 --- a/unit_tests/sources/file_based/file_types/test_parquet_parser.py +++ b/unit_tests/sources/file_based/file_types/test_parquet_parser.py @@ -1,15 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import asyncio import datetime import math -from typing import Any, Mapping, Union +from collections.abc import Mapping +from typing import Any from unittest.mock import Mock import pyarrow as pa import pytest +from pyarrow import Scalar + from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( FileBasedStreamConfig, @@ -19,7 +23,7 @@ from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat from airbyte_cdk.sources.file_based.file_types import ParquetParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from pyarrow import Scalar + _default_parquet_format = ParquetFormat() _decimal_as_float_parquet_format = ParquetFormat(decimal_as_float=True) @@ -220,7 +224,9 @@ def test_type_mapping( pytest.param(pa.uint32(), _default_parquet_format, 6, 6, id="test_parquet_uint32"), pytest.param(pa.uint64(), _default_parquet_format, 6, 6, id="test_parquet_uint64"), pytest.param(pa.float32(), _default_parquet_format, 2.7, 2.7, id="test_parquet_float32"), - pytest.param(pa.float64(), _default_parquet_format, 3.14, 3.14, id="test_parquet_float64"), + pytest.param( + pa.float64(), _default_parquet_format, math.pi, math.pi, id="test_parquet_float64" + ), pytest.param( pa.time32("s"), _default_parquet_format, @@ -325,9 +331,7 @@ def test_type_mapping( "this is a string", id="test_parquet_string", ), - pytest.param( - pa.utf8(), _default_parquet_format, "utf8".encode("utf8"), "utf8", id="test_utf8" - ), + pytest.param(pa.utf8(), _default_parquet_format, b"utf8", "utf8", id="test_utf8"), pytest.param( pa.large_binary(), _default_parquet_format, @@ -501,7 +505,7 @@ def test_null_value_does_not_throw(parquet_type, parquet_format) -> None: pytest.param(JsonlFormat(), id="test_jsonl_format"), ], ) -def test_wrong_file_format(file_format: Union[CsvFormat, JsonlFormat]) -> None: +def test_wrong_file_format(file_format: CsvFormat | JsonlFormat) -> None: parser = ParquetParser() config = FileBasedStreamConfig( name="test.parquet", diff --git a/unit_tests/sources/file_based/file_types/test_unstructured_parser.py b/unit_tests/sources/file_based/file_types/test_unstructured_parser.py index ea4e091a..783faea6 100644 --- a/unit_tests/sources/file_based/file_types/test_unstructured_parser.py +++ b/unit_tests/sources/file_based/file_types/test_unstructured_parser.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import asyncio from datetime import datetime @@ -9,6 +10,9 @@ import pytest import requests +from unstructured.documents.elements import ElementMetadata, Formula, ListItem, Text, Title +from unstructured.file_utils.filetype import FileType + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.config.unstructured_format import ( @@ -20,8 +24,7 @@ from airbyte_cdk.sources.file_based.file_types import UnstructuredParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unstructured.documents.elements import ElementMetadata, Formula, ListItem, Text, Title -from unstructured.file_utils.filetype import FileType + FILE_URI = "path/to/file.xyz" diff --git a/unit_tests/sources/file_based/helpers.py b/unit_tests/sources/file_based/helpers.py index 2138cdc5..7a563d37 100644 --- a/unit_tests/sources/file_based/helpers.py +++ b/unit_tests/sources/file_based/helpers.py @@ -1,11 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Mapping from datetime import datetime from io import IOBase -from typing import Any, Dict, List, Mapping, Optional +from typing import Any + +from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesStreamReader from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.discovery_policy import DefaultDiscoveryPolicy @@ -20,7 +24,6 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedConcurrentCursor from airbyte_cdk.sources.file_based.stream.cursor import DefaultFileBasedCursor -from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesStreamReader class EmptySchemaParser(CsvParser): @@ -30,7 +33,7 @@ async def infer_schema( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return {} @@ -46,9 +49,9 @@ class LowInferenceBytesJsonlParser(JsonlParser): class TestErrorListMatchingFilesInMemoryFilesStreamReader(InMemoryFilesStreamReader): def get_matching_files( self, - globs: List[str], - from_date: Optional[datetime] = None, - ) -> List[RemoteFile]: + globs: list[str], + from_date: datetime | None = None, + ) -> list[RemoteFile]: raise Exception("Error listing files") @@ -57,7 +60,7 @@ def open_file( self, file: RemoteFile, file_read_mode: FileReadMode, - encoding: Optional[str], + encoding: str | None, logger: logging.Logger, ) -> IOBase: raise Exception("Error opening file") @@ -68,7 +71,7 @@ class FailingSchemaValidationPolicy(AbstractSchemaValidationPolicy): validate_schema_before_sync = True def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: return False @@ -81,7 +84,7 @@ class LowHistoryLimitConcurrentCursor(FileBasedConcurrentCursor): DEFAULT_MAX_HISTORY_SIZE = 3 -def make_remote_files(files: List[str]) -> List[RemoteFile]: +def make_remote_files(files: list[str]) -> list[RemoteFile]: return [ RemoteFile( uri=f, diff --git 
a/unit_tests/sources/file_based/in_memory_files_source.py b/unit_tests/sources/file_based/in_memory_files_source.py index 1a6ef55b..ee34b443 100644 --- a/unit_tests/sources/file_based/in_memory_files_source.py +++ b/unit_tests/sources/file_based/in_memory_files_source.py @@ -1,21 +1,26 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import csv import io import json import logging import tempfile +from collections.abc import Iterable, Mapping from datetime import datetime from io import IOBase -from typing import Any, Dict, Iterable, List, Mapping, Optional +from typing import Any import avro.io as ai import avro.schema as avro_schema import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from avro import datafile +from pydantic.v1 import AnyUrl + from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteCatalogSerializer from airbyte_cdk.sources.file_based.availability_strategy import ( AbstractFileBasedAvailabilityStrategy, @@ -42,8 +47,6 @@ DefaultFileBasedCursor, ) from airbyte_cdk.sources.source import TState -from avro import datafile -from pydantic.v1 import AnyUrl class InMemoryFilesSource(FileBasedSource): @@ -53,16 +56,16 @@ def __init__( self, files: Mapping[str, Any], file_type: str, - availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy], - discovery_policy: Optional[AbstractDiscoveryPolicy], + availability_strategy: AbstractFileBasedAvailabilityStrategy | None, + discovery_policy: AbstractDiscoveryPolicy | None, validation_policies: Mapping[str, AbstractSchemaValidationPolicy], parsers: Mapping[str, FileTypeParser], - stream_reader: Optional[AbstractFileBasedStreamReader], - catalog: Optional[Mapping[str, Any]], - config: Optional[Mapping[str, Any]], - state: Optional[TState], + stream_reader: AbstractFileBasedStreamReader | None, + catalog: Mapping[str, Any] | None, + config: Mapping[str, Any] | None, + state: TState | None, file_write_options: Mapping[str, Any], - cursor_cls: Optional[AbstractFileBasedCursor], + cursor_cls: AbstractFileBasedCursor | None, ): # Attributes required for test purposes self.files = files @@ -103,7 +106,7 @@ def __init__( self, files: Mapping[str, Mapping[str, Any]], file_type: str, - file_write_options: Optional[Mapping[str, Any]] = None, + file_write_options: Mapping[str, Any] | None = None, ): self.files = files self.file_type = file_type @@ -111,7 +114,7 @@ def __init__( super().__init__() @property - def config(self) -> Optional[AbstractFileBasedSpec]: + def config(self) -> AbstractFileBasedSpec | None: return self._config @config.setter @@ -120,8 +123,8 @@ def config(self, value: AbstractFileBasedSpec) -> None: def get_matching_files( self, - globs: List[str], - prefix: Optional[str], + globs: list[str], + prefix: str | None, logger: logging.Logger, ) -> Iterable[RemoteFile]: yield from self.filter_files_by_globs_and_start_date( @@ -141,20 +144,19 @@ def file_size(self, file: RemoteFile) -> int: def get_file( self, file: RemoteFile, local_directory: str, logger: logging.Logger - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return {} def open_file( - self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger ) -> IOBase: if self.file_type == "csv": return self._make_csv_file_contents(file.uri) - elif self.file_type == "jsonl": + if self.file_type == "jsonl": return self._make_jsonl_file_contents(file.uri) - elif 
self.file_type == "unstructured": + if self.file_type == "unstructured": return self._make_binary_file_contents(file.uri) - else: - raise NotImplementedError(f"No implementation for file type: {self.file_type}") + raise NotImplementedError(f"No implementation for file type: {self.file_type}") def _make_csv_file_contents(self, file_name: str) -> IOBase: # Some tests define the csv as an array of strings to make it easier to validate the handling @@ -204,12 +206,10 @@ def documentation_url(cls) -> AnyUrl: class TemporaryParquetFilesStreamReader(InMemoryFilesStreamReader): - """ - A file reader that writes RemoteFiles to a temporary file and then reads them back. - """ + """A file reader that writes RemoteFiles to a temporary file and then reads them back.""" def open_file( - self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger ) -> IOBase: return io.BytesIO(self._create_file(file.uri)) @@ -227,12 +227,10 @@ def _create_file(self, file_name: str) -> bytes: class TemporaryAvroFilesStreamReader(InMemoryFilesStreamReader): - """ - A file reader that writes RemoteFiles to a temporary file and then reads them back. - """ + """A file reader that writes RemoteFiles to a temporary file and then reads them back.""" def open_file( - self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger ) -> IOBase: return io.BytesIO(self._make_file_contents(file.uri)) @@ -253,12 +251,10 @@ def _make_file_contents(self, file_name: str) -> bytes: class TemporaryExcelFilesStreamReader(InMemoryFilesStreamReader): - """ - A file reader that writes RemoteFiles to a temporary file and then reads them back. - """ + """A file reader that writes RemoteFiles to a temporary file and then reads them back.""" def open_file( - self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger ) -> IOBase: return io.BytesIO(self._make_file_contents(file.uri)) diff --git a/unit_tests/sources/file_based/scenarios/avro_scenarios.py b/unit_tests/sources/file_based/scenarios/avro_scenarios.py index 77f51c68..e2fff57d 100644 --- a/unit_tests/sources/file_based/scenarios/avro_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/avro_scenarios.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import decimal @@ -9,6 +10,7 @@ from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder + _single_avro_file = { "a.avro": { "schema": { diff --git a/unit_tests/sources/file_based/scenarios/check_scenarios.py b/unit_tests/sources/file_based/scenarios/check_scenarios.py index 9a235b9e..6e579569 100644 --- a/unit_tests/sources/file_based/scenarios/check_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/check_scenarios.py @@ -1,8 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError from unit_tests.sources.file_based.helpers import ( FailingSchemaValidationPolicy, TestErrorListMatchingFilesInMemoryFilesStreamReader, @@ -11,6 +11,9 @@ from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError + + _base_success_scenario = ( TestScenarioBuilder() .set_config( diff --git a/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py b/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py index 92ce67fe..125b6568 100644 --- a/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py @@ -1,9 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedConcurrentCursor -from airbyte_cdk.test.state_builder import StateBuilder from unit_tests.sources.file_based.helpers import LowHistoryLimitConcurrentCursor from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import ( @@ -11,6 +10,10 @@ TestScenarioBuilder, ) +from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedConcurrentCursor +from airbyte_cdk.test.state_builder import StateBuilder + + single_csv_input_state_is_earlier_scenario_concurrent = ( TestScenarioBuilder() .set_name("single_csv_input_state_is_earlier_concurrent") diff --git a/unit_tests/sources/file_based/scenarios/csv_scenarios.py b/unit_tests/sources/file_based/scenarios/csv_scenarios.py index 2f4f02cf..78115e4f 100644 --- a/unit_tests/sources/file_based/scenarios/csv_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/csv_scenarios.py @@ -1,12 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode -from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat -from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError -from airbyte_cdk.test.catalog_builder import CatalogBuilder -from airbyte_cdk.utils.traced_exception import AirbyteTracedException from unit_tests.sources.file_based.helpers import ( EmptySchemaParser, LowInferenceLimitDiscoveryPolicy, @@ -18,6 +14,13 @@ TestScenarioBuilder, ) +from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode +from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat +from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError +from airbyte_cdk.test.catalog_builder import CatalogBuilder +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + single_csv_scenario: TestScenario[InMemoryFilesSource] = ( TestScenarioBuilder[InMemoryFilesSource]() .set_name("single_csv_scenario") diff --git a/unit_tests/sources/file_based/scenarios/excel_scenarios.py b/unit_tests/sources/file_based/scenarios/excel_scenarios.py index 94ccc676..a5343c0e 100644 --- a/unit_tests/sources/file_based/scenarios/excel_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/excel_scenarios.py @@ -1,6 +1,7 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime @@ -8,6 +9,7 @@ from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder + _single_excel_file = { "a.xlsx": { "contents": [ diff --git a/unit_tests/sources/file_based/scenarios/file_based_source_builder.py b/unit_tests/sources/file_based/scenarios/file_based_source_builder.py index 4c2939f6..2f259040 100644 --- a/unit_tests/sources/file_based/scenarios/file_based_source_builder.py +++ b/unit_tests/sources/file_based/scenarios/file_based_source_builder.py @@ -1,9 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations +from collections.abc import Mapping from copy import deepcopy -from typing import Any, Mapping, Optional, Type +from typing import Any + +from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource +from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder from airbyte_cdk.sources.file_based.availability_strategy.abstract_file_based_availability_strategy import ( AbstractFileBasedAvailabilityStrategy, @@ -18,29 +23,27 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor from airbyte_cdk.sources.source import TState -from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource -from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder class FileBasedSourceBuilder(SourceBuilder[InMemoryFilesSource]): def __init__(self) -> None: self._files: Mapping[str, Any] = {} - self._file_type: Optional[str] = None - self._availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None + self._file_type: str | None = None + self._availability_strategy: AbstractFileBasedAvailabilityStrategy | None = None self._discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy() - self._validation_policies: Optional[Mapping[str, AbstractSchemaValidationPolicy]] = None + self._validation_policies: Mapping[str, AbstractSchemaValidationPolicy] | None = None self._parsers = default_parsers - self._stream_reader: Optional[AbstractFileBasedStreamReader] = None + self._stream_reader: AbstractFileBasedStreamReader | None = None self._file_write_options: Mapping[str, Any] = {} - self._cursor_cls: Optional[Type[AbstractFileBasedCursor]] = None - self._config: Optional[Mapping[str, Any]] = None - self._state: Optional[TState] = None + self._cursor_cls: type[AbstractFileBasedCursor] | None = None + self._config: Mapping[str, Any] | None = None + self._state: TState | None = None def build( self, - configured_catalog: Optional[Mapping[str, Any]], - config: Optional[Mapping[str, Any]], - state: Optional[TState], + configured_catalog: Mapping[str, Any] | None, + config: Mapping[str, Any] | None, + state: TState | None, ) -> InMemoryFilesSource: if self._file_type is None: raise ValueError("file_type is not set") @@ -59,51 +62,51 @@ def build( self._cursor_cls, ) - def set_files(self, files: Mapping[str, Any]) -> "FileBasedSourceBuilder": + def set_files(self, files: Mapping[str, Any]) -> FileBasedSourceBuilder: self._files = files return self - def set_file_type(self, file_type: str) -> "FileBasedSourceBuilder": + def set_file_type(self, file_type: str) -> FileBasedSourceBuilder: self._file_type = file_type return self - def set_parsers(self, parsers: Mapping[Type[Any], FileTypeParser]) -> "FileBasedSourceBuilder": + def set_parsers(self, parsers: Mapping[type[Any], FileTypeParser]) -> FileBasedSourceBuilder: self._parsers = parsers return self def set_availability_strategy( self, availability_strategy: AbstractFileBasedAvailabilityStrategy - ) -> "FileBasedSourceBuilder": + ) -> FileBasedSourceBuilder: self._availability_strategy = availability_strategy return self def set_discovery_policy( self, discovery_policy: AbstractDiscoveryPolicy - ) -> "FileBasedSourceBuilder": + ) -> FileBasedSourceBuilder: self._discovery_policy = discovery_policy return self def set_validation_policies( self, validation_policies: Mapping[str, 
AbstractSchemaValidationPolicy] - ) -> "FileBasedSourceBuilder": + ) -> FileBasedSourceBuilder: self._validation_policies = validation_policies return self def set_stream_reader( self, stream_reader: AbstractFileBasedStreamReader - ) -> "FileBasedSourceBuilder": + ) -> FileBasedSourceBuilder: self._stream_reader = stream_reader return self - def set_cursor_cls(self, cursor_cls: AbstractFileBasedCursor) -> "FileBasedSourceBuilder": + def set_cursor_cls(self, cursor_cls: AbstractFileBasedCursor) -> FileBasedSourceBuilder: self._cursor_cls = cursor_cls return self def set_file_write_options( self, file_write_options: Mapping[str, Any] - ) -> "FileBasedSourceBuilder": + ) -> FileBasedSourceBuilder: self._file_write_options = file_write_options return self - def copy(self) -> "FileBasedSourceBuilder": + def copy(self) -> FileBasedSourceBuilder: return deepcopy(self) diff --git a/unit_tests/sources/file_based/scenarios/incremental_scenarios.py b/unit_tests/sources/file_based/scenarios/incremental_scenarios.py index aea4b484..e439817a 100644 --- a/unit_tests/sources/file_based/scenarios/incremental_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/incremental_scenarios.py @@ -1,9 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from airbyte_cdk.sources.file_based.stream.cursor import DefaultFileBasedCursor -from airbyte_cdk.test.state_builder import StateBuilder from unit_tests.sources.file_based.helpers import LowHistoryLimitCursor from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import ( @@ -11,6 +10,10 @@ TestScenarioBuilder, ) +from airbyte_cdk.sources.file_based.stream.cursor import DefaultFileBasedCursor +from airbyte_cdk.test.state_builder import StateBuilder + + single_csv_input_state_is_earlier_scenario = ( TestScenarioBuilder() .set_name("single_csv_input_state_is_earlier") diff --git a/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py b/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py index c4ebafca..638ca5a9 100644 --- a/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py @@ -1,10 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat -from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError -from airbyte_cdk.utils.traced_exception import AirbyteTracedException from unit_tests.sources.file_based.helpers import ( LowInferenceBytesJsonlParser, LowInferenceLimitDiscoveryPolicy, @@ -12,6 +10,11 @@ from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat +from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + single_jsonl_scenario = ( TestScenarioBuilder() .set_name("single_jsonl_scenario") diff --git a/unit_tests/sources/file_based/scenarios/parquet_scenarios.py b/unit_tests/sources/file_based/scenarios/parquet_scenarios.py index 5ddb8468..1926737f 100644 --- a/unit_tests/sources/file_based/scenarios/parquet_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/parquet_scenarios.py @@ -1,16 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import datetime import decimal +import math import pyarrow as pa -from airbyte_cdk.utils.traced_exception import AirbyteTracedException + from unit_tests.sources.file_based.in_memory_files_source import TemporaryParquetFilesStreamReader from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + _single_parquet_file = { "a.parquet": { "contents": [ @@ -103,7 +108,7 @@ 3, 4, 5, - 3.14, + math.pi, 5.0, "2020-01-01", datetime.date(2021, 1, 1), @@ -506,7 +511,7 @@ "col_uint16": 3, "col_uint32": 4, "col_uint64": 5, - "col_float32": 3.14, + "col_float32": math.pi, "col_float64": 5.0, "col_string": "2020-01-01", "col_date32": "2021-01-01", diff --git a/unit_tests/sources/file_based/scenarios/scenario_builder.py b/unit_tests/sources/file_based/scenarios/scenario_builder.py index da8c7ba8..b15a56fc 100644 --- a/unit_tests/sources/file_based/scenarios/scenario_builder.py +++ b/unit_tests/sources/file_based/scenarios/scenario_builder.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + from abc import ABC, abstractmethod +from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Generic, List, Mapping, Optional, Set, Tuple, Type, TypeVar +from typing import Any, Generic, TypeVar from airbyte_cdk.models import ( AirbyteAnalyticsTraceMessage, @@ -19,26 +22,24 @@ @dataclass class IncrementalScenarioConfig: - input_state: List[Mapping[str, Any]] = field(default_factory=list) - expected_output_state: Optional[Mapping[str, Any]] = None + input_state: list[Mapping[str, Any]] = field(default_factory=list) + expected_output_state: Mapping[str, Any] | None = None SourceType = TypeVar("SourceType", bound=AbstractSource) class SourceBuilder(ABC, Generic[SourceType]): - """ - A builder that creates a source instance of type SourceType - """ + """A builder that creates a source instance of type SourceType""" @abstractmethod def build( self, - configured_catalog: Optional[Mapping[str, Any]], - config: Optional[Mapping[str, Any]], - state: Optional[TState], + configured_catalog: Mapping[str, Any] | None, + config: Mapping[str, Any] | None, + state: TState | None, ) -> SourceType: - raise NotImplementedError() + raise NotImplementedError class TestScenario(Generic[SourceType]): @@ -47,18 +48,18 @@ def __init__( name: str, config: Mapping[str, Any], source: SourceType, - expected_spec: Optional[Mapping[str, Any]], - expected_check_status: Optional[str], - expected_catalog: Optional[Mapping[str, Any]], - expected_logs: Optional[Mapping[str, List[Mapping[str, Any]]]], - expected_records: List[Mapping[str, Any]], - expected_check_error: Tuple[Optional[Type[Exception]], Optional[str]], - expected_discover_error: Tuple[Optional[Type[Exception]], Optional[str]], - expected_read_error: Tuple[Optional[Type[Exception]], Optional[str]], - incremental_scenario_config: Optional[IncrementalScenarioConfig], - expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] = None, - log_levels: Optional[Set[str]] = None, - catalog: Optional[ConfiguredAirbyteCatalog] = None, + expected_spec: Mapping[str, Any] | None, + expected_check_status: str | None, + expected_catalog: Mapping[str, Any] | None, + expected_logs: Mapping[str, list[Mapping[str, Any]]] | None, + expected_records: list[Mapping[str, Any]], + expected_check_error: tuple[type[Exception] | None, str | None], + expected_discover_error: tuple[type[Exception] | None, str | None], + expected_read_error: tuple[type[Exception] | None, str | None], + incremental_scenario_config: IncrementalScenarioConfig | None, + expected_analytics: list[AirbyteAnalyticsTraceMessage] | None = None, + log_levels: set[str] | None = None, + catalog: ConfiguredAirbyteCatalog | None = None, ): if log_levels is None: log_levels = {"ERROR", "WARN", "WARNING"} @@ -82,7 +83,7 @@ def __init__( def validate(self) -> None: assert self.name - def configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]]: + def configured_catalog(self, sync_mode: SyncMode) -> Mapping[str, Any] | None: # The preferred way of returning the catalog for the TestScenario is by providing it at the initialization. 
The previous solution # relied on `self.source.streams` which might raise an exception hence screwing the tests results as the user might expect the # exception to be raised as part of the actual check/discover/read commands @@ -106,121 +107,118 @@ def configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]] return catalog - def input_state(self) -> List[Mapping[str, Any]]: + def input_state(self) -> list[Mapping[str, Any]]: if self.incremental_scenario_config: return self.incremental_scenario_config.input_state - else: - return [] + return [] class TestScenarioBuilder(Generic[SourceType]): - """ - A builder that creates a TestScenario instance for a source of type SourceType - """ + """A builder that creates a TestScenario instance for a source of type SourceType""" def __init__(self) -> None: self._name = "" self._config: Mapping[str, Any] = {} - self._catalog: Optional[ConfiguredAirbyteCatalog] = None - self._expected_spec: Optional[Mapping[str, Any]] = None - self._expected_check_status: Optional[str] = None + self._catalog: ConfiguredAirbyteCatalog | None = None + self._expected_spec: Mapping[str, Any] | None = None + self._expected_check_status: str | None = None self._expected_catalog: Mapping[str, Any] = {} - self._expected_logs: Optional[Mapping[str, Any]] = None - self._expected_records: List[Mapping[str, Any]] = [] - self._expected_check_error: Tuple[Optional[Type[Exception]], Optional[str]] = None, None - self._expected_discover_error: Tuple[Optional[Type[Exception]], Optional[str]] = None, None - self._expected_read_error: Tuple[Optional[Type[Exception]], Optional[str]] = None, None - self._incremental_scenario_config: Optional[IncrementalScenarioConfig] = None - self._expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] = None - self.source_builder: Optional[SourceBuilder[SourceType]] = None + self._expected_logs: Mapping[str, Any] | None = None + self._expected_records: list[Mapping[str, Any]] = [] + self._expected_check_error: tuple[type[Exception] | None, str | None] = None, None + self._expected_discover_error: tuple[type[Exception] | None, str | None] = None, None + self._expected_read_error: tuple[type[Exception] | None, str | None] = None, None + self._incremental_scenario_config: IncrementalScenarioConfig | None = None + self._expected_analytics: list[AirbyteAnalyticsTraceMessage] | None = None + self.source_builder: SourceBuilder[SourceType] | None = None self._log_levels = None - def set_name(self, name: str) -> "TestScenarioBuilder[SourceType]": + def set_name(self, name: str) -> TestScenarioBuilder[SourceType]: self._name = name return self - def set_config(self, config: Mapping[str, Any]) -> "TestScenarioBuilder[SourceType]": + def set_config(self, config: Mapping[str, Any]) -> TestScenarioBuilder[SourceType]: self._config = config return self def set_expected_spec( self, expected_spec: Mapping[str, Any] - ) -> "TestScenarioBuilder[SourceType]": + ) -> TestScenarioBuilder[SourceType]: self._expected_spec = expected_spec return self - def set_catalog(self, catalog: ConfiguredAirbyteCatalog) -> "TestScenarioBuilder[SourceType]": + def set_catalog(self, catalog: ConfiguredAirbyteCatalog) -> TestScenarioBuilder[SourceType]: self._catalog = catalog return self def set_expected_check_status( self, expected_check_status: str - ) -> "TestScenarioBuilder[SourceType]": + ) -> TestScenarioBuilder[SourceType]: self._expected_check_status = expected_check_status return self def set_expected_catalog( self, expected_catalog: Mapping[str, Any] - ) 
-> "TestScenarioBuilder[SourceType]": + ) -> TestScenarioBuilder[SourceType]: self._expected_catalog = expected_catalog return self def set_expected_logs( - self, expected_logs: Mapping[str, List[Mapping[str, Any]]] - ) -> "TestScenarioBuilder[SourceType]": + self, expected_logs: Mapping[str, list[Mapping[str, Any]]] + ) -> TestScenarioBuilder[SourceType]: self._expected_logs = expected_logs return self def set_expected_records( - self, expected_records: Optional[List[Mapping[str, Any]]] - ) -> "TestScenarioBuilder[SourceType]": + self, expected_records: list[Mapping[str, Any]] | None + ) -> TestScenarioBuilder[SourceType]: self._expected_records = expected_records return self def set_incremental_scenario_config( self, incremental_scenario_config: IncrementalScenarioConfig - ) -> "TestScenarioBuilder[SourceType]": + ) -> TestScenarioBuilder[SourceType]: self._incremental_scenario_config = incremental_scenario_config return self def set_expected_check_error( - self, error: Optional[Type[Exception]], message: str - ) -> "TestScenarioBuilder[SourceType]": + self, error: type[Exception] | None, message: str + ) -> TestScenarioBuilder[SourceType]: self._expected_check_error = error, message return self def set_expected_discover_error( - self, error: Type[Exception], message: str - ) -> "TestScenarioBuilder[SourceType]": + self, error: type[Exception], message: str + ) -> TestScenarioBuilder[SourceType]: self._expected_discover_error = error, message return self def set_expected_read_error( - self, error: Type[Exception], message: str - ) -> "TestScenarioBuilder[SourceType]": + self, error: type[Exception], message: str + ) -> TestScenarioBuilder[SourceType]: self._expected_read_error = error, message return self - def set_log_levels(self, levels: Set[str]) -> "TestScenarioBuilder": + def set_log_levels(self, levels: set[str]) -> TestScenarioBuilder: self._log_levels = levels return self def set_source_builder( self, source_builder: SourceBuilder[SourceType] - ) -> "TestScenarioBuilder[SourceType]": + ) -> TestScenarioBuilder[SourceType]: self.source_builder = source_builder return self def set_expected_analytics( - self, expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] - ) -> "TestScenarioBuilder[SourceType]": + self, expected_analytics: list[AirbyteAnalyticsTraceMessage] | None + ) -> TestScenarioBuilder[SourceType]: self._expected_analytics = expected_analytics return self - def copy(self) -> "TestScenarioBuilder[SourceType]": + def copy(self) -> TestScenarioBuilder[SourceType]: return deepcopy(self) - def build(self) -> "TestScenario[SourceType]": + def build(self) -> TestScenario[SourceType]: if self.source_builder is None: raise ValueError("source_builder is not set") if self._incremental_scenario_config and self._incremental_scenario_config.input_state: @@ -255,7 +253,7 @@ def build(self) -> "TestScenario[SourceType]": self._catalog, ) - def _configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]]: + def _configured_catalog(self, sync_mode: SyncMode) -> Mapping[str, Any] | None: if not self._expected_catalog: return None catalog: Mapping[str, Any] = {"streams": []} diff --git a/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py b/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py index c8d3dae9..cb3fedb5 100644 --- a/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py @@ -1,13 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights 
reserved. # +from __future__ import annotations + import base64 import nltk -from airbyte_cdk.utils.traced_exception import AirbyteTracedException + from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + # import nltk data for pdf parser nltk.download("punkt") nltk.download("averaged_perceptron_tagger") diff --git a/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py b/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py index 9d233921..8f179d94 100644 --- a/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py @@ -1,13 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder +from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError from airbyte_cdk.test.catalog_builder import CatalogBuilder -from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder + """ User input schema rules: diff --git a/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py b/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py index d6ed7e9a..dba5fbe7 100644 --- a/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations - -from airbyte_cdk.utils.traced_exception import AirbyteTracedException from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + _base_single_stream_scenario = ( TestScenarioBuilder() .set_source_builder( diff --git a/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py b/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py index 9cbf33e5..a42e4179 100644 --- a/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py +++ b/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py @@ -1,16 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import pytest + from airbyte_cdk.sources.file_based.config.file_based_stream_config import ValidationPolicy from airbyte_cdk.sources.file_based.exceptions import StopSyncPerValidationPolicy from airbyte_cdk.sources.file_based.schema_validation_policies import ( DEFAULT_SCHEMA_VALIDATION_POLICIES, ) + CONFORMING_RECORD = { "col1": "val1", "col2": 1, diff --git a/unit_tests/sources/file_based/stream/concurrent/test_adapters.py b/unit_tests/sources/file_based/stream/concurrent/test_adapters.py index 3c271dfe..4d42dc8c 100644 --- a/unit_tests/sources/file_based/stream/concurrent/test_adapters.py +++ b/unit_tests/sources/file_based/stream/concurrent/test_adapters.py @@ -1,12 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import logging import unittest from datetime import datetime from unittest.mock import MagicMock, Mock import pytest +from freezegun import freeze_time + from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.file_based.availability_strategy import ( @@ -32,7 +36,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.record import Record from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer -from freezegun import freeze_time + _ANY_SYNC_MODE = SyncMode.full_refresh _ANY_STATE = {"state_key": "state_value"} diff --git a/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py b/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py index ce48f845..179c7a3d 100644 --- a/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py +++ b/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py @@ -1,24 +1,27 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations - +from collections.abc import MutableMapping from datetime import datetime -from typing import Any, Dict, List, MutableMapping, Optional, Tuple +from typing import Any from unittest.mock import MagicMock import pytest +from freezegun import freeze_time + from airbyte_cdk.models import AirbyteStateMessage, SyncMode from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedConcurrentCursor from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from freezegun import freeze_time + DATE_TIME_FORMAT = FileBasedConcurrentCursor.DATE_TIME_FORMAT MOCK_DAYS_TO_SYNC_IF_HISTORY_IS_FULL = 3 -def _make_cursor(input_state: Optional[MutableMapping[str, Any]]) -> FileBasedConcurrentCursor: +def _make_cursor(input_state: MutableMapping[str, Any] | None) -> FileBasedConcurrentCursor: stream = MagicMock() stream.name = "test" stream.namespace = None @@ -100,7 +103,7 @@ def _make_cursor(input_state: Optional[MutableMapping[str, Any]]) -> FileBasedCo ], ) def test_compute_prev_sync_cursor( - input_state: MutableMapping[str, Any], expected_cursor_value: Tuple[datetime, str] + input_state: MutableMapping[str, Any], expected_cursor_value: tuple[datetime, str] ): cursor = _make_cursor(input_state) assert cursor._compute_prev_sync_cursor(input_state) == expected_cursor_value @@ -188,10 +191,10 @@ def test_compute_prev_sync_cursor( ) def test_add_file( initial_state: MutableMapping[str, Any], - pending_files: List[Tuple[str, str]], - file_to_add: Tuple[str, str], - expected_history: Dict[str, Any], - expected_pending_files: List[Tuple[str, str]], + pending_files: list[tuple[str, str]], + file_to_add: tuple[str, str], + expected_history: dict[str, Any], + expected_pending_files: list[tuple[str, str]], expected_cursor_value: str, ): cursor = _make_cursor(initial_state) @@ -262,10 +265,10 @@ def test_add_file( ) def test_add_file_invalid( initial_state: MutableMapping[str, Any], - pending_files: List[Tuple[str, str]], - file_to_add: Tuple[str, str], - expected_history: Dict[str, Any], - expected_pending_files: List[Tuple[str, str]], + pending_files: list[tuple[str, str]], + file_to_add: tuple[str, str], + expected_history: dict[str, Any], + expected_pending_files: list[tuple[str, str]], expected_cursor_value: str, ): cursor = _make_cursor(initial_state) @@ -328,7 +331,7 @@ def test_add_file_invalid( ) def test_get_new_cursor_value( input_state: MutableMapping[str, Any], - pending_files: List[Tuple[str, str]], + pending_files: list[tuple[str, str]], expected_cursor_value: str, ): cursor = _make_cursor(input_state) @@ -534,9 +537,9 @@ def test_get_files_to_sync( ) def test_should_sync_file( file_to_check: RemoteFile, - history: Dict[str, Any], + history: dict[str, Any], is_history_full: bool, - prev_cursor_value: Tuple[datetime, str], + prev_cursor_value: tuple[datetime, str], sync_start: datetime, expected_should_sync: bool, ): diff --git a/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py b/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py index 6cd3e20b..f4d21228 100644 --- a/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py +++ b/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py @@ -1,12 +1,16 @@ # # Copyright (c) 2023 Airbyte, 
Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from datetime import datetime, timedelta -from typing import Any, List, Mapping +from typing import Any from unittest.mock import MagicMock import pytest +from freezegun import freeze_time + from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( FileBasedStreamConfig, @@ -16,7 +20,6 @@ from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import ( DefaultFileBasedCursor, ) -from freezegun import freeze_time @pytest.mark.parametrize( @@ -162,8 +165,8 @@ ], ) def test_add_file( - files_to_add: List[RemoteFile], - expected_start_time: List[datetime], + files_to_add: list[RemoteFile], + expected_start_time: list[datetime], expected_state_dict: Mapping[str, Any], ) -> None: cursor = get_cursor(max_history_size=3, days_to_sync_if_history_is_full=3) @@ -283,8 +286,8 @@ def test_add_file( ], ) def test_get_files_to_sync( - files: List[RemoteFile], - expected_files_to_sync: List[RemoteFile], + files: list[RemoteFile], + expected_files_to_sync: list[RemoteFile], max_history_size: int, history_is_partial: bool, ) -> None: diff --git a/unit_tests/sources/file_based/stream/test_default_file_based_stream.py b/unit_tests/sources/file_based/stream/test_default_file_based_stream.py index 8eea01d6..1a869768 100644 --- a/unit_tests/sources/file_based/stream/test_default_file_based_stream.py +++ b/unit_tests/sources/file_based/stream/test_default_file_based_stream.py @@ -1,15 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import traceback import unittest +from collections.abc import Iterable, Iterator, Mapping from datetime import datetime, timezone -from typing import Any, Iterable, Iterator, Mapping +from typing import Any from unittest import mock from unittest.mock import Mock import pytest + from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.file_based.availability_strategy import ( @@ -122,8 +125,7 @@ def test_when_transform_record_then_return_updated_record(self) -> None: def test_given_exception_when_read_records_from_slice_then_do_process_other_files( self, ) -> None: - """ - The current behavior for source-s3 v3 does not fail sync on some errors and hence, we will keep this behaviour for now. One example + """The current behavior for source-s3 v3 does not fail sync on some errors and hence, we will keep this behaviour for now. One example we can easily reproduce this is by having a file with gzip extension that is not actually a gzip file. The reader will fail to open the file but the sync won't fail. Ticket: https://github.com/airbytehq/airbyte/issues/29680 @@ -150,9 +152,7 @@ def test_given_exception_when_read_records_from_slice_then_do_process_other_file def test_given_traced_exception_when_read_records_from_slice_then_fail( self, ) -> None: - """ - When a traced exception is raised, the stream shouldn't try to handle but pass it on to the caller. 
- """ + """When a traced exception is raised, the stream shouldn't try to handle but pass it on to the caller.""" self._parser.parse_records.side_effect = [AirbyteTracedException("An error")] with pytest.raises(AirbyteTracedException): diff --git a/unit_tests/sources/file_based/test_file_based_scenarios.py b/unit_tests/sources/file_based/test_file_based_scenarios.py index a930192f..507a0cb0 100644 --- a/unit_tests/sources/file_based/test_file_based_scenarios.py +++ b/unit_tests/sources/file_based/test_file_based_scenarios.py @@ -1,13 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from pathlib import PosixPath import pytest from _pytest.capture import CaptureFixture -from airbyte_cdk.sources.abstract_source import AbstractSource from freezegun import freeze_time + from unit_tests.sources.file_based.scenarios.avro_scenarios import ( avro_all_types_scenario, avro_file_with_double_as_number_scenario, @@ -166,6 +167,9 @@ verify_spec, ) +from airbyte_cdk.sources.abstract_source import AbstractSource + + discover_failure_scenarios = [ empty_schema_inference_scenario, ] diff --git a/unit_tests/sources/file_based/test_file_based_stream_reader.py b/unit_tests/sources/file_based/test_file_based_stream_reader.py index 66729af4..cb8fefc3 100644 --- a/unit_tests/sources/file_based/test_file_based_stream_reader.py +++ b/unit_tests/sources/file_based/test_file_based_stream_reader.py @@ -1,17 +1,22 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Iterable, Mapping from io import IOBase -from typing import Any, Dict, Iterable, List, Mapping, Optional, Set +from typing import Any import pytest +from pydantic.v1 import AnyUrl + +from unit_tests.sources.file_based.helpers import make_remote_files + from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from pydantic.v1 import AnyUrl -from unit_tests.sources.file_based.helpers import make_remote_files + reader = AbstractFileBasedStreamReader @@ -61,14 +66,14 @@ class TestStreamReader(AbstractFileBasedStreamReader): @property - def config(self) -> Optional[AbstractFileBasedSpec]: + def config(self) -> AbstractFileBasedSpec | None: return self._config @config.setter def config(self, value: AbstractFileBasedSpec) -> None: self._config = value - def get_matching_files(self, globs: List[str]) -> Iterable[RemoteFile]: + def get_matching_files(self, globs: list[str]) -> Iterable[RemoteFile]: pass def open_file(self, file: RemoteFile) -> IOBase: @@ -79,7 +84,7 @@ def file_size(self, file: RemoteFile) -> int: def get_file( self, file: RemoteFile, local_directory: str, logger: logging.Logger - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return {} @@ -352,10 +357,10 @@ def documentation_url(cls) -> AnyUrl: ], ) def test_globs_and_prefixes_from_globs( - globs: List[str], + globs: list[str], config: Mapping[str, Any], - expected_matches: Set[str], - expected_path_prefixes: Set[str], + expected_matches: set[str], + expected_path_prefixes: set[str], ) -> None: reader = TestStreamReader() reader.config = TestSpec(**config) diff --git a/unit_tests/sources/file_based/test_scenarios.py b/unit_tests/sources/file_based/test_scenarios.py index 14da7176..8ae546ff 100644 --- a/unit_tests/sources/file_based/test_scenarios.py +++ 
b/unit_tests/sources/file_based/test_scenarios.py @@ -1,15 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import math +from collections.abc import Mapping from pathlib import Path, PosixPath -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any import pytest from _pytest.capture import CaptureFixture from _pytest.reports import ExceptionInfo + +from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario + from airbyte_cdk.entrypoint import launch from airbyte_cdk.models import ( AirbyteAnalyticsTraceMessage, @@ -26,7 +31,6 @@ from airbyte_cdk.test.entrypoint_wrapper import read as entrypoint_read from airbyte_cdk.utils import message_utils from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario def verify_discover( @@ -109,7 +113,7 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac assert len(sorted_records) == len(sorted_expected_records) - for actual, expected in zip(sorted_records, sorted_expected_records): + for actual, expected in zip(sorted_records, sorted_expected_records, strict=False): if actual.record: assert len(actual.record.data) == len(expected["data"]) for key, value in actual.record.data.items(): @@ -136,7 +140,7 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac } == expected_states[-1] else: for actual, expected in zip( - states, expected_states + states, expected_states, strict=False ): # states should be emitted in sorted order assert {k: v for k, v in actual.state.stream.stream_state.__dict__.items()} == expected @@ -152,7 +156,7 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac def _verify_state_record_counts( - records: List[AirbyteMessage], states: List[AirbyteMessage] + records: list[AirbyteMessage], states: list[AirbyteMessage] ) -> None: actual_record_counts = {} for record in records: @@ -177,14 +181,14 @@ def _verify_state_record_counts( def _verify_analytics( - analytics: List[AirbyteMessage], - expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]], + analytics: list[AirbyteMessage], + expected_analytics: list[AirbyteAnalyticsTraceMessage] | None, ) -> None: if expected_analytics: assert ( len(analytics) == len(expected_analytics) ), f"Number of actual analytics messages ({len(analytics)}) did not match expected ({len(expected_analytics)})" - for actual, expected in zip(analytics, expected_analytics): + for actual, expected in zip(analytics, expected_analytics, strict=False): actual_type, actual_value = actual.trace.analytics.type, actual.trace.analytics.value expected_type = expected.type expected_value = expected.value @@ -193,10 +197,10 @@ def _verify_analytics( def _verify_expected_logs( - logs: List[AirbyteLogMessage], expected_logs: Optional[List[Mapping[str, Any]]] + logs: list[AirbyteLogMessage], expected_logs: list[Mapping[str, Any]] | None ) -> None: if expected_logs: - for actual, expected in zip(logs, expected_logs): + for actual, expected in zip(logs, expected_logs, strict=False): actual_level, actual_message = actual.level.value, actual.message expected_level = expected["level"] expected_message = expected["message"] @@ -236,7 +240,7 @@ def spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> def check( capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] -) -> 
Dict[str, Any]: +) -> dict[str, Any]: launch( scenario.source, ["check", "--config", make_file(tmp_path / "config.json", scenario.config)], @@ -245,7 +249,7 @@ def check( return _find_connection_status(captured.out.splitlines()) -def _find_connection_status(output: List[str]) -> Mapping[str, Any]: +def _find_connection_status(output: list[str]) -> Mapping[str, Any]: for line in output: json_line = json.loads(line) if "connectionStatus" in json_line: @@ -255,7 +259,7 @@ def _find_connection_status(output: List[str]) -> Mapping[str, Any]: def discover( capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] -) -> Dict[str, Any]: +) -> dict[str, Any]: launch( scenario.source, ["discover", "--config", make_file(tmp_path / "config.json", scenario.config)], @@ -285,9 +289,7 @@ def read_with_state(scenario: TestScenario[AbstractSource]) -> EntrypointOutput: ) -def make_file( - path: Path, file_contents: Optional[Union[Mapping[str, Any], List[Mapping[str, Any]]]] -) -> str: +def make_file(path: Path, file_contents: Mapping[str, Any] | list[Mapping[str, Any]] | None) -> str: path.write_text(json.dumps(file_contents)) return str(path) diff --git a/unit_tests/sources/file_based/test_schema_helpers.py b/unit_tests/sources/file_based/test_schema_helpers.py index 20bcadcc..dd060521 100644 --- a/unit_tests/sources/file_based/test_schema_helpers.py +++ b/unit_tests/sources/file_based/test_schema_helpers.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any import pytest + from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, SchemaInferenceError from airbyte_cdk.sources.file_based.schema_helpers import ( ComparableType, @@ -14,6 +17,7 @@ type_mapping_to_jsonschema, ) + COMPLETE_CONFORMING_RECORD = { "null_field": None, "boolean_field": True, @@ -353,7 +357,7 @@ def test_comparable_types() -> None: ], ) def test_merge_schemas( - schema1: SchemaType, schema2: SchemaType, expected_result: Optional[SchemaType] + schema1: SchemaType, schema2: SchemaType, expected_result: SchemaType | None ) -> None: if expected_result is not None: assert merge_schemas(schema1, schema2) == expected_result @@ -432,8 +436,8 @@ def test_merge_schemas( ) def test_type_mapping_to_jsonschema( type_mapping: Mapping[str, Any], - expected_schema: Optional[Mapping[str, Any]], - expected_exc_msg: Optional[str], + expected_schema: Mapping[str, Any] | None, + expected_exc_msg: str | None, ) -> None: if expected_exc_msg: with pytest.raises(ConfigValidationError) as exc: diff --git a/unit_tests/sources/fixtures/source_test_fixture.py b/unit_tests/sources/fixtures/source_test_fixture.py index 620ad8c4..3f6ede52 100644 --- a/unit_tests/sources/fixtures/source_test_fixture.py +++ b/unit_tests/sources/fixtures/source_test_fixture.py @@ -1,13 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import json import logging from abc import ABC -from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any import requests +from requests.auth import AuthBase + from airbyte_cdk.models import ( AirbyteStream, ConfiguredAirbyteCatalog, @@ -20,19 +24,15 @@ from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.http import HttpStream from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator -from requests.auth import AuthBase class SourceTestFixture(AbstractSource): - """ - This is a concrete implementation of a Source connector that provides implementations of all the methods needed to run sync + """This is a concrete implementation of a Source connector that provides implementations of all the methods needed to run sync operations. For simplicity, it also overrides functions that read from files in favor of returning the data directly avoiding the need to load static files (ex. spec.yaml, config.json, configured_catalog.json) into the unit-test package. """ - def __init__( - self, streams: Optional[List[Stream]] = None, authenticator: Optional[AuthBase] = None - ): + def __init__(self, streams: list[Stream] | None = None, authenticator: AuthBase | None = None): self._streams = streams self._authenticator = authenticator @@ -76,10 +76,10 @@ def read_catalog(cls, catalog_path: str) -> ConfiguredAirbyteCatalog: ] ) - def check_connection(self, *args, **kwargs) -> Tuple[bool, Optional[Any]]: + def check_connection(self, *args, **kwargs) -> tuple[bool, Any | None]: return True, "" - def streams(self, *args, **kwargs) -> List[Stream]: + def streams(self, *args, **kwargs) -> list[Stream]: return [HttpTestStream(authenticator=self._authenticator)] @@ -87,14 +87,14 @@ class HttpTestStream(HttpStream, ABC): url_base = "https://airbyte.com/api/v1/" @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return ["updated_at"] @property def availability_strategy(self): return None - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return "id" def path( @@ -117,7 +117,7 @@ def parse_response( body = response.json() or {} return body["records"] - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def get_json_schema(self) -> Mapping[str, Any]: @@ -125,8 +125,7 @@ def get_json_schema(self) -> Mapping[str, Any]: def fixture_mock_send(self, request, **kwargs) -> requests.Response: - """ - Helper method that can be used by a test to patch the Session.send() function and mock the outbound send operation to provide + """Helper method that can be used by a test to patch the Session.send() function and mock the outbound send operation to provide faster and more reliable responses compared to actual API requests """ response = requests.Response() @@ -146,11 +145,9 @@ def fixture_mock_send(self, request, **kwargs) -> requests.Response: class SourceFixtureOauthAuthenticator(Oauth2Authenticator): - """ - Test OAuth authenticator that only overrides the request and response aspect of the authenticator flow - """ + """Test OAuth authenticator that only overrides the request and response aspect of the authenticator flow""" - def refresh_access_token(self) -> Tuple[str, int]: 
+ def refresh_access_token(self) -> tuple[str, int]: response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), params={}) response.raise_for_status() return ( diff --git a/unit_tests/sources/message/test_repository.py b/unit_tests/sources/message/test_repository.py index 6d637a6b..d636d43f 100644 --- a/unit_tests/sources/message/test_repository.py +++ b/unit_tests/sources/message/test_repository.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from unittest.mock import Mock import pytest + from airbyte_cdk.models import ( AirbyteControlConnectorConfigMessage, AirbyteControlMessage, @@ -20,6 +22,7 @@ NoopMessageRepository, ) + A_CONTROL = AirbyteControlMessage( type=OrchestratorType.CONNECTOR_CONFIG, emitted_at=0, @@ -119,7 +122,7 @@ def test_given_message_emitted_when_consume_queue_then_return_empty(self): class TestLogAppenderMessageRepositoryDecorator: _DICT_TO_APPEND = {"airbyte_cdk": {"stream": {"is_substream": False}}} - @pytest.fixture() + @pytest.fixture def decorated(self): return Mock(spec=MessageRepository) @@ -148,14 +151,14 @@ def test_given_log_level_is_severe_enough_when_log_message_then_allow_message_to self, decorated ): repo = LogAppenderMessageRepositoryDecorator(self._DICT_TO_APPEND, decorated, Level.DEBUG) - repo.log_message(Level.INFO, lambda: {}) + repo.log_message(Level.INFO, dict) assert decorated.log_message.call_count == 1 def test_given_log_level_not_severe_enough_when_log_message_then_do_not_allow_message_to_be_consumed( self, decorated ): repo = LogAppenderMessageRepositoryDecorator(self._DICT_TO_APPEND, decorated, Level.ERROR) - repo.log_message(Level.INFO, lambda: {}) + repo.log_message(Level.INFO, dict) assert decorated.log_message.call_count == 0 def test_when_consume_queue_then_return_delegate_queue(self, decorated): diff --git a/unit_tests/sources/mock_server_tests/mock_source_fixture.py b/unit_tests/sources/mock_server_tests/mock_source_fixture.py index b5927219..36d0c0fd 100644 --- a/unit_tests/sources/mock_server_tests/mock_source_fixture.py +++ b/unit_tests/sources/mock_server_tests/mock_source_fixture.py @@ -1,32 +1,33 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging from abc import ABC +from collections.abc import Iterable, Mapping, MutableMapping from datetime import datetime, timezone -from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple +from typing import Any import pendulum import requests +from requests import HTTPError + from airbyte_cdk.models import ConnectorSpecification, SyncMode from airbyte_cdk.sources import AbstractSource, Source from airbyte_cdk.sources.streams import CheckpointMixin, IncrementalMixin, Stream from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.streams.http import HttpStream from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy -from requests import HTTPError class FixtureAvailabilityStrategy(HttpAvailabilityStrategy): - """ - Inherit from HttpAvailabilityStrategy with slight modification to 403 error message. 
- """ + """Inherit from HttpAvailabilityStrategy with slight modification to 403 error message.""" def reasons_for_unavailable_status_codes( self, stream: Stream, logger: logging.Logger, source: Source, error: HTTPError - ) -> Dict[int, str]: - reasons_for_codes: Dict[int, str] = { + ) -> dict[int, str]: + reasons_for_codes: dict[int, str] = { requests.codes.FORBIDDEN: "This is likely due to insufficient permissions for your Notion integration. " "Please make sure your integration has read access for the resources you are trying to sync" } @@ -52,13 +53,12 @@ def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapp data = response.json().get("data", []) yield from data - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: has_more = response.json().get("has_more") if has_more: self.current_page += 1 return {"page": self.current_page} - else: - return None + return None class IncrementalIntegrationStream(IntegrationStream, IncrementalMixin, ABC): @@ -76,9 +76,9 @@ def state(self, value: MutableMapping[str, Any]) -> None: def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: for record in super().read_records(sync_mode, cursor_field, stream_slice, stream_state): self.state = {self.cursor_field: record.get(self.cursor_field)} @@ -127,9 +127,9 @@ def get_json_schema(self) -> Mapping[str, Any]: def request_params( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return { "start_date": stream_slice.get("start_date"), @@ -140,9 +140,9 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: start_date = pendulum.parse(self.start_date) if stream_state: @@ -166,8 +166,7 @@ def stream_slices( class Legacies(IntegrationStream): - """ - Incremental stream that uses the legacy method get_updated_state() to manage stream state. New connectors use the state + """Incremental stream that uses the legacy method get_updated_state() to manage stream state. New connectors use the state property and setter methods. 
""" @@ -201,17 +200,17 @@ def get_updated_state( def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: yield from super().read_records(sync_mode, cursor_field, stream_slice, stream_state) def request_params( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return { "start_date": stream_slice.get("start_date"), @@ -222,9 +221,9 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: start_date = pendulum.parse(self.start_date) if stream_state: @@ -268,16 +267,16 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: return [{"divide_category": "dukes"}, {"divide_category": "mentats"}] def request_params( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return {"category": stream_slice.get("divide_category")} @@ -325,26 +324,26 @@ def state(self, value: MutableMapping[str, Any]) -> None: def request_params( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return {"page": next_page_token.get("page")} def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: yield from self._read_single_page(cursor_field, stream_slice, stream_state) def _read_single_page( self, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: next_page_token = stream_slice request_headers = self.request_headers( @@ -387,7 +386,7 @@ def _read_single_page( self.next_page_token(response) - def 
next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: current_page = self._state.get("page") or 0 has_more = response.json().get("has_more") if has_more: @@ -399,10 +398,10 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, class SourceFixture(AbstractSource): def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, any]: + ) -> tuple[bool, any]: return True, None - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: return [ Dividers(config=config), JusticeSongs(config=config), diff --git a/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py b/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py index 4e2f0051..ab0fe845 100644 --- a/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py +++ b/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py @@ -1,14 +1,14 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # - -from typing import List +from __future__ import annotations import pytest + from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus, Type -def emits_successful_sync_status_messages(status_messages: List[AirbyteStreamStatus]) -> bool: +def emits_successful_sync_status_messages(status_messages: list[AirbyteStreamStatus]) -> bool: return ( len(status_messages) == 3 and status_messages[0] == AirbyteStreamStatus.STARTED @@ -17,7 +17,7 @@ def emits_successful_sync_status_messages(status_messages: List[AirbyteStreamSta ) -def validate_message_order(expected_message_order: List[Type], messages: List[AirbyteMessage]): +def validate_message_order(expected_message_order: list[Type], messages: list[AirbyteMessage]): if len(expected_message_order) != len(messages): pytest.fail( f"Expected message order count {len(expected_message_order)} did not match actual messages {len(messages)}" diff --git a/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py b/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py index 0670b28c..fedb0f96 100644 --- a/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py +++ b/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py @@ -1,12 +1,19 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from datetime import datetime, timedelta, timezone -from typing import List, Optional from unittest import TestCase import freezegun + +from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture +from unit_tests.sources.mock_server_tests.test_helpers import ( + emits_successful_sync_status_messages, + validate_message_order, +) + from airbyte_cdk.models import AirbyteStateBlob, ConfiguredAirbyteCatalog, SyncMode, Type from airbyte_cdk.test.catalog_builder import CatalogBuilder from airbyte_cdk.test.entrypoint_wrapper import read @@ -20,56 +27,52 @@ create_response_builder, ) from airbyte_cdk.test.state_builder import StateBuilder -from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture -from unit_tests.sources.mock_server_tests.test_helpers import ( - emits_successful_sync_status_messages, - validate_message_order, -) + _NOW = datetime.now(timezone.utc) class RequestBuilder: @classmethod - def dividers_endpoint(cls) -> "RequestBuilder": + def dividers_endpoint(cls) -> RequestBuilder: return cls("dividers") @classmethod - def justice_songs_endpoint(cls) -> "RequestBuilder": + def justice_songs_endpoint(cls) -> RequestBuilder: return cls("justice_songs") @classmethod - def legacies_endpoint(cls) -> "RequestBuilder": + def legacies_endpoint(cls) -> RequestBuilder: return cls("legacies") @classmethod - def planets_endpoint(cls) -> "RequestBuilder": + def planets_endpoint(cls) -> RequestBuilder: return cls("planets") @classmethod - def users_endpoint(cls) -> "RequestBuilder": + def users_endpoint(cls) -> RequestBuilder: return cls("users") def __init__(self, resource: str) -> None: self._resource = resource - self._start_date: Optional[datetime] = None - self._end_date: Optional[datetime] = None - self._category: Optional[str] = None - self._page: Optional[int] = None + self._start_date: datetime | None = None + self._end_date: datetime | None = None + self._category: str | None = None + self._page: int | None = None - def with_start_date(self, start_date: datetime) -> "RequestBuilder": + def with_start_date(self, start_date: datetime) -> RequestBuilder: self._start_date = start_date return self - def with_end_date(self, end_date: datetime) -> "RequestBuilder": + def with_end_date(self, end_date: datetime) -> RequestBuilder: self._end_date = end_date return self - def with_category(self, category: str) -> "RequestBuilder": + def with_category(self, category: str) -> RequestBuilder: self._category = category return self - def with_page(self, page: int) -> "RequestBuilder": + def with_page(self, page: int) -> RequestBuilder: self._page = page return self @@ -90,7 +93,7 @@ def build(self) -> HttpRequest: ) -def _create_catalog(names_and_sync_modes: List[tuple[str, SyncMode]]) -> ConfiguredAirbyteCatalog: +def _create_catalog(names_and_sync_modes: list[tuple[str, SyncMode]]) -> ConfiguredAirbyteCatalog: catalog_builder = CatalogBuilder() for stream_name, sync_mode in names_and_sync_modes: catalog_builder.with_stream(name=stream_name, sync_mode=sync_mode) diff --git a/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py b/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py index b3e3c2ac..7c403e0b 100644 --- a/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py +++ b/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py @@ -1,12 +1,20 @@ # # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Any from unittest import TestCase import freezegun + +from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture +from unit_tests.sources.mock_server_tests.test_helpers import ( + emits_successful_sync_status_messages, + validate_message_order, +) + from airbyte_cdk.models import ( AirbyteStateBlob, AirbyteStreamStatus, @@ -27,25 +35,21 @@ create_response_builder, ) from airbyte_cdk.test.state_builder import StateBuilder -from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture -from unit_tests.sources.mock_server_tests.test_helpers import ( - emits_successful_sync_status_messages, - validate_message_order, -) + _NOW = datetime.now(timezone.utc) class RequestBuilder: @classmethod - def justice_songs_endpoint(cls) -> "RequestBuilder": + def justice_songs_endpoint(cls) -> RequestBuilder: return cls("justice_songs") def __init__(self, resource: str) -> None: self._resource = resource - self._page: Optional[int] = None + self._page: int | None = None - def with_page(self, page: int) -> "RequestBuilder": + def with_page(self, page: int) -> RequestBuilder: self._page = page return self @@ -61,7 +65,7 @@ def build(self) -> HttpRequest: def _create_catalog( - names_and_sync_modes: List[tuple[str, SyncMode, Dict[str, Any]]], + names_and_sync_modes: list[tuple[str, SyncMode, dict[str, Any]]], ) -> ConfiguredAirbyteCatalog: stream_builder = ConfiguredAirbyteStreamBuilder() streams = [] diff --git a/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py b/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py index 35db9c24..0ac26360 100644 --- a/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py +++ b/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py @@ -1,8 +1,10 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from unittest.mock import Mock import pytest + from airbyte_cdk.sources.streams.checkpoint import ( CursorBasedCheckpointReader, FullRefreshCheckpointReader, diff --git a/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py b/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py index f023b038..5aa3a5aa 100644 --- a/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py +++ b/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py @@ -1,6 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations import pytest + from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import ( SubstreamResumableFullRefreshCursor, ) @@ -9,9 +11,7 @@ def test_substream_resumable_full_refresh_cursor(): - """ - Test scenario where a set of parent record partitions are iterated over by the cursor resulting in a completed sync - """ + """Test scenario where a set of parent record partitions are iterated over by the cursor resulting in a completed sync""" expected_starting_state = {"states": []} expected_ending_state = { @@ -47,9 +47,7 @@ def test_substream_resumable_full_refresh_cursor(): def test_substream_resumable_full_refresh_cursor_with_state(): - """ - Test scenario where a set of parent record partitions are iterated over and previously completed parents are skipped - """ + """Test scenario where a set of parent record partitions are iterated over and previously completed parents are skipped""" initial_state = { "states": [ { diff --git a/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py index c8bd0429..62d0f491 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py @@ -1,12 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( - ConcurrencyCompatibleStateType, -) -from airbyte_cdk.test.state_builder import StateBuilder -from airbyte_cdk.utils.traced_exception import AirbyteTracedException +from __future__ import annotations + from unit_tests.sources.file_based.scenarios.scenario_builder import ( IncrementalScenarioConfig, TestScenarioBuilder, @@ -16,6 +12,14 @@ ) from unit_tests.sources.streams.concurrent.scenarios.utils import MockStream +from airbyte_cdk.sources.streams.concurrent.cursor import CursorField +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( + ConcurrencyCompatibleStateType, +) +from airbyte_cdk.test.state_builder import StateBuilder +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + _NO_SLICE_BOUNDARIES = None _NO_INPUT_STATE = [] test_incremental_stream_without_slice_boundaries_no_input_state = ( diff --git a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py index 50695ba1..788c622b 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py @@ -1,10 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import concurrent import logging -from typing import Any, List, Mapping, Optional, Tuple, Union +from collections.abc import Mapping +from typing import Any + +from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder +from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( + NeverLogSliceLogger, +) from airbyte_cdk.models import ( AirbyteStateMessage, @@ -25,10 +32,7 @@ from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( EpochValueConcurrentStreamStateConverter, ) -from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder -from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( - NeverLogSliceLogger, -) + _CURSOR_FIELD = "cursor_field" _NO_STATE = None @@ -41,11 +45,11 @@ class StreamFacadeConcurrentConnectorStateConverter(EpochValueConcurrentStreamSt class StreamFacadeSource(ConcurrentSourceAdapter): def __init__( self, - streams: List[Stream], + streams: list[Stream], threadpool: concurrent.futures.ThreadPoolExecutor, - cursor_field: Optional[CursorField] = None, - cursor_boundaries: Optional[Tuple[str, str]] = None, - input_state: Optional[List[Mapping[str, Any]]] = _NO_STATE, + cursor_field: CursorField | None = None, + cursor_boundaries: tuple[str, str] | None = None, + input_state: list[Mapping[str, Any]] | None = _NO_STATE, ): self._message_repository = InMemoryMessageRepository() threadpool_manager = ThreadPoolManager(threadpool, streams[0].logger) @@ -61,10 +65,10 @@ def __init__( def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: return True, None - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: state_manager = ConnectorStateManager( state=self._state, ) # The input values into the AirbyteStream are dummy values; the connector state manager only uses `name` and `namespace` @@ -88,7 +92,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: ] @property - def message_repository(self) -> Union[None, MessageRepository]: + def message_repository(self) -> None | MessageRepository: return self._message_repository def spec(self, logger: logging.Logger) -> ConnectorSpecification: @@ -117,30 +121,30 @@ def __init__(self): self._input_state = None self._raw_input_state = None - def set_streams(self, streams: List[Stream]) -> "StreamFacadeSourceBuilder": + def set_streams(self, streams: list[Stream]) -> StreamFacadeSourceBuilder: self._streams = streams return self - def set_max_workers(self, max_workers: int) -> "StreamFacadeSourceBuilder": + def set_max_workers(self, max_workers: int) -> StreamFacadeSourceBuilder: self._max_workers = max_workers return self def set_incremental( - self, cursor_field: CursorField, cursor_boundaries: Optional[Tuple[str, str]] - ) -> "StreamFacadeSourceBuilder": + self, cursor_field: CursorField, cursor_boundaries: tuple[str, str] | None + ) -> StreamFacadeSourceBuilder: self._cursor_field = cursor_field self._cursor_boundaries = cursor_boundaries return self - def set_input_state(self, state: List[Mapping[str, Any]]) -> "StreamFacadeSourceBuilder": + def set_input_state(self, state: list[Mapping[str, Any]]) -> StreamFacadeSourceBuilder: self._input_state = state return self def build( self, - configured_catalog: Optional[Mapping[str, 
Any]], - config: Optional[Mapping[str, Any]], - state: Optional[TState], + configured_catalog: Mapping[str, Any] | None, + config: Mapping[str, Any] | None, + state: TState | None, ) -> StreamFacadeSource: threadpool = concurrent.futures.ThreadPoolExecutor( max_workers=self._max_workers, thread_name_prefix="workerpool" diff --git a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py index 36fc90e9..8beaeeb7 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py @@ -1,8 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from airbyte_cdk.utils.traced_exception import AirbyteTracedException +from __future__ import annotations + from unit_tests.sources.file_based.scenarios.scenario_builder import ( IncrementalScenarioConfig, TestScenarioBuilder, @@ -12,6 +12,10 @@ ) from unit_tests.sources.streams.concurrent.scenarios.utils import MockStream +from airbyte_cdk.sources.streams.concurrent.cursor import CursorField +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + _stream1 = MockStream( [ (None, [{"id": "1"}, {"id": "2"}]), diff --git a/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py index a0abaec0..f0ca3239 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py @@ -1,12 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from pathlib import PosixPath import pytest from _pytest.capture import CaptureFixture from freezegun import freeze_time + from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario from unit_tests.sources.file_based.test_scenarios import verify_discover, verify_read from unit_tests.sources.streams.concurrent.scenarios.incremental_scenarios import ( @@ -38,6 +40,7 @@ test_concurrent_cdk_single_stream_with_primary_key, ) + scenarios = [ test_concurrent_cdk_single_stream, test_concurrent_cdk_multiple_streams, diff --git a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py index 2de8bfd0..00ac2d73 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py @@ -1,9 +1,17 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import logging +from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder +from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( + ConcurrentSourceBuilder, + InMemoryPartition, + InMemoryPartitionGenerator, +) + from airbyte_cdk.sources.message import InMemoryMessageRepository from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( AlwaysAvailableAvailabilityStrategy, @@ -12,12 +20,7 @@ from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.partitions.record import Record from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder -from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( - ConcurrentSourceBuilder, - InMemoryPartition, - InMemoryPartitionGenerator, -) + _message_repository = InMemoryMessageRepository() diff --git a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py index 98633daf..03442e4c 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py @@ -1,9 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import json import logging -from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any + +from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder from airbyte_cdk.models import ( ConfiguredAirbyteCatalog, @@ -24,19 +29,18 @@ from airbyte_cdk.sources.streams.concurrent.partitions.record import Record from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.utils.slice_logger import SliceLogger -from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder class LegacyStream(Stream): - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return None def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: yield from [] @@ -44,8 +48,8 @@ def read_records( class ConcurrentCdkSource(ConcurrentSourceAdapter): def __init__( self, - streams: List[DefaultStream], - message_repository: Optional[MessageRepository], + streams: list[DefaultStream], + message_repository: MessageRepository | None, max_workers, timeout_in_seconds, ): @@ -58,11 +62,11 @@ def __init__( def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: # Check is not verified because it is up to the source to implement this method return True, None - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: return [ 
StreamFacade( s, @@ -104,12 +108,12 @@ def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog: ) @property - def message_repository(self) -> Union[None, MessageRepository]: + def message_repository(self) -> None | MessageRepository: return self._message_repository class InMemoryPartitionGenerator(PartitionGenerator): - def __init__(self, partitions: List[Partition]): + def __init__(self, partitions: list[Partition]): self._partitions = partitions def generate(self) -> Iterable[Partition]: @@ -134,7 +138,7 @@ def read(self) -> Iterable[Record]: else: yield record_or_exception - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: return self._slice def __hash__(self) -> int: @@ -142,8 +146,7 @@ def __hash__(self) -> int: # Convert the slice to a string so that it can be hashed s = json.dumps(self._slice, sort_keys=True) return hash((self._name, s)) - else: - return hash(self._name) + return hash(self._name) def close(self) -> None: self._is_closed = True @@ -154,19 +157,19 @@ def is_closed(self) -> bool: class ConcurrentSourceBuilder(SourceBuilder[ConcurrentCdkSource]): def __init__(self): - self._streams: List[DefaultStream] = [] + self._streams: list[DefaultStream] = [] self._message_repository = None - def build(self, configured_catalog: Optional[Mapping[str, Any]], _, __) -> ConcurrentCdkSource: + def build(self, configured_catalog: Mapping[str, Any] | None, _, __) -> ConcurrentCdkSource: return ConcurrentCdkSource(self._streams, self._message_repository, 1, 1) - def set_streams(self, streams: List[DefaultStream]) -> "ConcurrentSourceBuilder": + def set_streams(self, streams: list[DefaultStream]) -> ConcurrentSourceBuilder: self._streams = streams return self def set_message_repository( self, message_repository: MessageRepository - ) -> "ConcurrentSourceBuilder": + ) -> ConcurrentSourceBuilder: self._message_repository = message_repository return self diff --git a/unit_tests/sources/streams/concurrent/scenarios/utils.py b/unit_tests/sources/streams/concurrent/scenarios/utils.py index 627891ee..542e7cba 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/utils.py +++ b/unit_tests/sources/streams/concurrent/scenarios/utils.py @@ -1,7 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# -from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.streams import Stream @@ -12,7 +15,7 @@ class MockStream(Stream): def __init__( self, slices_and_records_or_exception: Iterable[ - Tuple[Optional[Mapping[str, Any]], Iterable[Union[Exception, Mapping[str, Any]]]] + tuple[Mapping[str, Any] | None, Iterable[Exception | Mapping[str, Any]]] ], name, json_schema, @@ -28,9 +31,9 @@ def __init__( def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: for _slice, records_or_exception in self._slices_and_records_or_exception: if stream_slice == _slice: @@ -40,7 +43,7 @@ def read_records( yield item @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return self._primary_key @property @@ -48,7 +51,7 @@ def name(self) -> str: return self._name @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return self._cursor_field or [] def get_json_schema(self) -> Mapping[str, Any]: @@ -58,9 +61,9 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: if self._slices_and_records_or_exception: yield from [ _slice for _slice, records_or_exception in self._slices_and_records_or_exception diff --git a/unit_tests/sources/streams/concurrent/test_adapters.py b/unit_tests/sources/streams/concurrent/test_adapters.py index cbebfe7c..be8f7fc5 100644 --- a/unit_tests/sources/streams/concurrent/test_adapters.py +++ b/unit_tests/sources/streams/concurrent/test_adapters.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import datetime import logging import unittest from unittest.mock import Mock import pytest + from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.message import InMemoryMessageRepository @@ -33,6 +36,7 @@ from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer + _ANY_SYNC_MODE = SyncMode.full_refresh _ANY_STATE = {"state_key": "state_value"} _ANY_CURSOR_FIELD = ["a", "cursor", "key"] diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index f6f6ecfb..a571db17 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import logging import unittest from unittest.mock import Mock, call import freezegun import pytest + from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, @@ -15,9 +18,11 @@ AirbyteStreamStatus, AirbyteStreamStatusTraceMessage, AirbyteTraceMessage, + StreamDescriptor, + SyncMode, + TraceType, ) from airbyte_cdk.models import Level as LogLevel -from airbyte_cdk.models import StreamDescriptor, SyncMode, TraceType from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( @@ -35,6 +40,7 @@ from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.utils.traced_exception import AirbyteTracedException + _STREAM_NAME = "stream" _ANOTHER_STREAM_NAME = "stream2" _ANY_AIRBYTE_MESSAGE = Mock(spec=AirbyteMessage) diff --git a/unit_tests/sources/streams/concurrent/test_cursor.py b/unit_tests/sources/streams/concurrent/test_cursor.py index 883f2418..f9b81d2e 100644 --- a/unit_tests/sources/streams/concurrent/test_cursor.py +++ b/unit_tests/sources/streams/concurrent/test_cursor.py @@ -1,15 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations +from collections.abc import Mapping from datetime import datetime, timedelta, timezone from functools import partial -from typing import Any, Mapping, Optional +from typing import Any from unittest import TestCase from unittest.mock import Mock import freezegun import pytest +from isodate import parse_duration + from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor @@ -28,7 +32,7 @@ EpochValueConcurrentStreamStateConverter, IsoMillisConcurrentStreamStateConverter, ) -from isodate import parse_duration + _A_STREAM_NAME = "a stream name" _A_STREAM_NAMESPACE = "a stream namespace" @@ -44,9 +48,7 @@ _NO_LOOKBACK_WINDOW = timedelta(seconds=0) -def _partition( - _slice: Optional[Mapping[str, Any]], _stream_name: Optional[str] = Mock() -) -> Partition: +def _partition(_slice: Mapping[str, Any] | None, _stream_name: str | None = Mock()) -> Partition: partition = Mock(spec=Partition) partition.to_slice.return_value = _slice partition.stream_name.return_value = _stream_name @@ -54,7 +56,7 @@ def _partition( def _record( - cursor_value: CursorValueType, partition: Optional[Partition] = Mock(spec=Partition) + cursor_value: CursorValueType, partition: Partition | None = Mock(spec=Partition) ) -> Record: return Record(data={_A_CURSOR_FIELD_KEY: cursor_value}, partition=partition) @@ -650,9 +652,7 @@ def test_last_slice_without_records_when_close_then_most_recent_cursor_value_is_ def test_most_recent_cursor_value_outside_of_boundaries_when_close_then_most_recent_cursor_value_still_considered( self, ) -> None: - """ - Not sure what is the value of this behavior but I'm simply documenting how it is today - """ + """Not sure what is the value of this behavior but I'm simply documenting how it is today""" cursor = self._cursor_with_slice_boundary_fields(is_sequential_state=False) partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10}) cursor.observe(_record(15, partition=partition)) diff --git 
a/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py b/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py index 32272973..835e5951 100644 --- a/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py +++ b/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from datetime import datetime, timezone import pytest + from airbyte_cdk.sources.streams.concurrent.cursor import CursorField from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( ConcurrencyCompatibleStateType, @@ -231,7 +233,9 @@ def test_convert_from_sequential_state(converter, start, sequential_state, expec ) assert conversion["state_type"] == expected_output_state["state_type"] assert conversion["legacy"] == expected_output_state["legacy"] - for actual, expected in zip(conversion["slices"], expected_output_state["slices"]): + for actual, expected in zip( + conversion["slices"], expected_output_state["slices"], strict=False + ): assert actual["start"].strftime(comparison_format) == expected["start"].strftime( comparison_format ) diff --git a/unit_tests/sources/streams/concurrent/test_default_stream.py b/unit_tests/sources/streams/concurrent/test_default_stream.py index 25b15ca2..1695e7b4 100644 --- a/unit_tests/sources/streams/concurrent/test_default_stream.py +++ b/unit_tests/sources/streams/concurrent/test_default_stream.py @@ -1,6 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import unittest from unittest.mock import Mock diff --git a/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py b/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py index 02c1bdd1..b262ea3f 100644 --- a/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py +++ b/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import unittest +from collections.abc import Callable, Iterable from queue import Queue -from typing import Callable, Iterable, List from unittest.mock import Mock, patch from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( @@ -16,7 +18,8 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem -_SOME_PARTITIONS: List[Partition] = [Mock(spec=Partition), Mock(spec=Partition)] + +_SOME_PARTITIONS: list[Partition] = [Mock(spec=Partition), Mock(spec=Partition)] _A_STREAM_NAME = "a_stream_name" @@ -88,7 +91,7 @@ def test_given_exception_when_generate_partitions_then_return_exception_and_sent ] def _partitions_before_raising( - self, partitions: List[Partition], exception: Exception + self, partitions: list[Partition], exception: Exception ) -> Callable[[], Iterable[Partition]]: def inner_function() -> Iterable[Partition]: for partition in partitions: @@ -98,13 +101,13 @@ def inner_function() -> Iterable[Partition]: return inner_function @staticmethod - def _a_stream(partitions: List[Partition]) -> AbstractStream: + def _a_stream(partitions: list[Partition]) -> AbstractStream: stream = Mock(spec=AbstractStream) stream.generate_partitions.return_value = iter(partitions) return stream - def _consume_queue(self) -> List[QueueItem]: - queue_content: List[QueueItem] = [] + def _consume_queue(self) -> list[QueueItem]: + queue_content: list[QueueItem] = [] while queue_item := self._queue.get(): if isinstance(queue_item, PartitionGenerationCompletedSentinel): queue_content.append(queue_item) diff --git a/unit_tests/sources/streams/concurrent/test_partition_reader.py b/unit_tests/sources/streams/concurrent/test_partition_reader.py index 16d70719..839dff16 100644 --- a/unit_tests/sources/streams/concurrent/test_partition_reader.py +++ b/unit_tests/sources/streams/concurrent/test_partition_reader.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + import unittest +from collections.abc import Callable, Iterable from queue import Queue -from typing import Callable, Iterable, List from unittest.mock import Mock import pytest + from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition @@ -16,6 +19,7 @@ QueueItem, ) + _RECORDS = [ Record({"id": 1, "name": "Jack"}, "stream"), Record({"id": 2, "name": "John"}, "stream"), @@ -60,14 +64,14 @@ def test_given_exception_when_process_partition_then_queue_records_and_exception PartitionCompleteSentinel(partition), ] - def _a_partition(self, records: List[Record]) -> Partition: + def _a_partition(self, records: list[Record]) -> Partition: partition = Mock(spec=Partition) partition.read.return_value = iter(records) return partition @staticmethod def _read_with_exception( - records: List[Record], exception: Exception + records: list[Record], exception: Exception ) -> Callable[[], Iterable[Record]]: def mocked_function() -> Iterable[Record]: yield from records diff --git a/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py b/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py index d4820db9..908c7e44 100644 --- a/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py +++ b/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py @@ -1,6 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + from concurrent.futures import Future, ThreadPoolExecutor from unittest import TestCase from unittest.mock import Mock diff --git a/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py b/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py index ee487e97..905c6b9a 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py +++ b/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py @@ -1,10 +1,11 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. - -from typing import Optional, Union +from __future__ import annotations import requests + from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, DefaultBackoffStrategy + _ANY_ATTEMPT_COUNT = 123 @@ -17,9 +18,9 @@ def test_given_no_arguments_default_backoff_strategy_returns_default_values(): class CustomBackoffStrategy(BackoffStrategy): def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: return response_or_exception.headers["Retry-After"] diff --git a/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py b/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py index 355d20b8..db962530 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py +++ b/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py @@ -1,10 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations + from unittest.mock import MagicMock import pytest import requests + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.http.error_handlers import ( ErrorResolution, @@ -12,6 +15,7 @@ ResponseAction, ) + logger = MagicMock() diff --git a/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py b/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py index fecaa13f..1636e6be 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py +++ b/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py @@ -1,9 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest import requests + from airbyte_cdk.sources.streams.http.error_handlers import JsonErrorMessageParser diff --git a/unit_tests/sources/streams/http/error_handlers/test_response_models.py b/unit_tests/sources/streams/http/error_handlers/test_response_models.py index 7d0eb776..417f561f 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_response_models.py +++ b/unit_tests/sources/streams/http/error_handlers/test_response_models.py @@ -1,9 +1,11 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations from unittest import TestCase import requests import requests_mock + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( ResponseAction, @@ -11,6 +13,7 @@ ) from airbyte_cdk.utils.airbyte_secrets_utils import update_secrets + _A_SECRET = "a-secret" _A_URL = "https://a-url.com" diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index c8345c42..604b5c29 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -1,16 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging -from typing import Optional, Union from unittest.mock import Mock import freezegun import pendulum import pytest import requests +from requests import Response +from requests.exceptions import RequestException + from airbyte_cdk.models import FailureType, OrchestratorType, Type from airbyte_cdk.sources.streams.http.requests_native_auth import ( BasicHttpAuthenticator, @@ -20,8 +23,7 @@ TokenAuthenticator, ) from airbyte_cdk.utils import AirbyteTracedException -from requests import Response -from requests.exceptions import RequestException + LOGGER = logging.getLogger(__name__) @@ -29,9 +31,7 @@ def test_token_authenticator(): - """ - Should match passed in token, no matter how many times token is retrieved. 
- """ + """Should match passed in token, no matter how many times token is retrieved.""" token_auth = TokenAuthenticator(token="test-token") header1 = token_auth.get_auth_header() header2 = token_auth.get_auth_header() @@ -40,15 +40,13 @@ def test_token_authenticator(): prepared_request.headers = {} token_auth(prepared_request) - assert {"Authorization": "Bearer test-token"} == prepared_request.headers - assert {"Authorization": "Bearer test-token"} == header1 - assert {"Authorization": "Bearer test-token"} == header2 + assert prepared_request.headers == {"Authorization": "Bearer test-token"} + assert header1 == {"Authorization": "Bearer test-token"} + assert header2 == {"Authorization": "Bearer test-token"} def test_basic_http_authenticator(): - """ - Should match passed in token, no matter how many times token is retrieved. - """ + """Should match passed in token, no matter how many times token is retrieved.""" token_auth = BasicHttpAuthenticator(username="user", password="password") header1 = token_auth.get_auth_header() header2 = token_auth.get_auth_header() @@ -57,9 +55,9 @@ def test_basic_http_authenticator(): prepared_request.headers = {} token_auth(prepared_request) - assert {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} == prepared_request.headers - assert {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} == header1 - assert {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} == header2 + assert prepared_request.headers == {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} + assert header1 == {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} + assert header2 == {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} def test_multiple_token_authenticator(): @@ -72,16 +70,14 @@ def test_multiple_token_authenticator(): prepared_request.headers = {} multiple_token_auth(prepared_request) - assert {"Authorization": "Bearer token2"} == prepared_request.headers - assert {"Authorization": "Bearer token1"} == header1 - assert {"Authorization": "Bearer token2"} == header2 - assert {"Authorization": "Bearer token1"} == header3 + assert prepared_request.headers == {"Authorization": "Bearer token2"} + assert header1 == {"Authorization": "Bearer token1"} + assert header2 == {"Authorization": "Bearer token2"} + assert header3 == {"Authorization": "Bearer token1"} class TestOauth2Authenticator: - """ - Test class for OAuth2Authenticator. - """ + """Test class for OAuth2Authenticator.""" refresh_endpoint = "refresh_end" client_id = "client_id" @@ -89,9 +85,7 @@ class TestOauth2Authenticator: refresh_token = "refresh_token" def test_get_auth_header_fresh(self, mocker): - """ - Should not retrieve new token if current token is valid. - """ + """Should not retrieve new token if current token is valid.""" oauth = Oauth2Authenticator( token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint, client_id=TestOauth2Authenticator.client_id, @@ -103,12 +97,10 @@ def test_get_auth_header_fresh(self, mocker): Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) ) header = oauth.get_auth_header() - assert {"Authorization": "Bearer access_token"} == header + assert header == {"Authorization": "Bearer access_token"} def test_get_auth_header_expired(self, mocker): - """ - Should retrieve new token if current token is expired. 
- """ + """Should retrieve new token if current token is expired.""" oauth = Oauth2Authenticator( token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint, client_id=TestOauth2Authenticator.client_id, @@ -131,12 +123,10 @@ def test_get_auth_header_expired(self, mocker): return_value=("access_token_2", valid_100_secs), ) header = oauth.get_auth_header() - assert {"Authorization": "Bearer access_token_2"} == header + assert header == {"Authorization": "Bearer access_token_2"} def test_refresh_request_body(self): - """ - Request body should match given configuration. - """ + """Request body should match given configuration.""" scopes = ["scope1", "scope2"] oauth = Oauth2Authenticator( token_refresh_endpoint="refresh_end", @@ -187,7 +177,7 @@ def test_refresh_access_token(self, mocker): token, expires_in = oauth.refresh_access_token() assert isinstance(expires_in, int) - assert ("access_token", 1000) == (token, expires_in) + assert (token, expires_in) == ("access_token", 1000) # Test with expires_in as str mocker.patch.object( @@ -196,7 +186,7 @@ def test_refresh_access_token(self, mocker): token, expires_in = oauth.refresh_access_token() assert isinstance(expires_in, str) - assert ("access_token", "2000") == (token, expires_in) + assert (token, expires_in) == ("access_token", "2000") # Test with expires_in as str mocker.patch.object( @@ -207,7 +197,7 @@ def test_refresh_access_token(self, mocker): token, expires_in = oauth.refresh_access_token() assert isinstance(expires_in, str) - assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + assert (token, expires_in) == ("access_token", "2022-04-24T00:00:00Z") @pytest.mark.parametrize( "expires_in_response, token_expiry_date_format, expected_token_expiry_date", @@ -227,8 +217,8 @@ def test_refresh_access_token(self, mocker): def test_parse_refresh_token_lifespan( self, mocker, - expires_in_response: Union[str, int], - token_expiry_date_format: Optional[str], + expires_in_response: str | int, + token_expiry_date_format: str | None, expected_token_expiry_date: pendulum.DateTime, ): oauth = Oauth2Authenticator( @@ -297,7 +287,7 @@ def test_auth_call_method(self, mocker): prepared_request.headers = {} oauth(prepared_request) - assert {"Authorization": "Bearer access_token"} == prepared_request.headers + assert prepared_request.headers == {"Authorization": "Bearer access_token"} @pytest.mark.parametrize( ( diff --git a/unit_tests/sources/streams/http/test_availability_strategy.py b/unit_tests/sources/streams/http/test_availability_strategy.py index 766d0e35..bc87d7f3 100644 --- a/unit_tests/sources/streams/http/test_availability_strategy.py +++ b/unit_tests/sources/streams/http/test_availability_strategy.py @@ -1,17 +1,21 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import io import json import logging -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any import pytest import requests + from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy from airbyte_cdk.sources.streams.http.http import HttpStream + logger = logging.getLogger("airbyte") @@ -23,7 +27,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.resp_counter = 1 - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def path(self, **kwargs) -> str: @@ -77,8 +81,7 @@ class MockListHttpStream(MockHttpStream): def read_records(self, *args, **kvargs): if records_as_list: return list(super().read_records(*args, **kvargs)) - else: - return super().read_records(*args, **kvargs) + return super().read_records(*args, **kvargs) http_stream = MockListHttpStream() response = requests.Response() @@ -104,10 +107,10 @@ def test_http_availability_raises_unhandled_error(mocker): req.status_code = 404 mocker.patch.object(requests.Session, "send", return_value=req) - assert ( + assert HttpAvailabilityStrategy().check_availability(http_stream, logger) == ( False, "Not found. The requested resource was not found on the server.", - ) == HttpAvailabilityStrategy().check_availability(http_stream, logger) + ) def test_send_handles_retries_when_checking_availability(mocker, caplog): diff --git a/unit_tests/sources/streams/http/test_http.py b/unit_tests/sources/streams/http/test_http.py index 77a5fca3..4a16a851 100644 --- a/unit_tests/sources/streams/http/test_http.py +++ b/unit_tests/sources/streams/http/test_http.py @@ -1,16 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - +from __future__ import annotations import json import logging +from collections.abc import Callable, Iterable, Mapping, MutableMapping from http import HTTPStatus -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any from unittest.mock import ANY, MagicMock, patch import pytest import requests + from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type from airbyte_cdk.sources.streams import CheckpointMixin from airbyte_cdk.sources.streams.checkpoint import ResumableFullRefreshCursor @@ -39,7 +41,7 @@ def __init__(self, deduplicate_query_params: bool = False, **kwargs): self.resp_counter = 1 self._deduplicate_query_params = deduplicate_query_params - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def path(self, **kwargs) -> str: @@ -54,7 +56,7 @@ def must_deduplicate_query_params(self) -> bool: return self._deduplicate_query_params @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return ["updated_at"] @@ -90,7 +92,7 @@ def test_stub_basic_read_http_stream_read_records(mocker): records = list(stream.read_records(SyncMode.full_refresh)) - assert [{"data": 1}] == records + assert records == [{"data": 1}] class StubNextPageTokenHttpStream(StubBasicReadHttpStream): @@ -100,7 +102,7 @@ def __init__(self, pages: int = 5): super().__init__() self._pages = pages - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: while self.current_page < self._pages: page_token = {"page": self.current_page} self.current_page += 1 @@ -152,7 +154,7 @@ def test_stub_bad_url_http_stream_read_records(mocker): class StubCustomBackoffHttpStream(StubBasicReadHttpStream): - def backoff_time(self, response: requests.Response) -> Optional[float]: + def backoff_time(self, response: requests.Response) -> float | None: return 0.5 @@ -180,7 +182,7 @@ class StubCustomBackoffHttpStreamRetries(StubCustomBackoffHttpStream): def max_retries(self): return retries - def get_error_handler(self) -> Optional[ErrorHandler]: + def get_error_handler(self) -> ErrorHandler | None: return HttpStatusErrorHandler(logging.Logger, max_retries=retries) stream = StubCustomBackoffHttpStreamRetries() @@ -202,7 +204,7 @@ def test_stub_custom_backoff_http_stream_endless_retries(mocker): mocker.patch("time.sleep", lambda x: None) class StubCustomBackoffHttpStreamRetries(StubCustomBackoffHttpStream): - def get_error_handler(self) -> Optional[ErrorHandler]: + def get_error_handler(self) -> ErrorHandler | None: return HttpStatusErrorHandler(logging.Logger, max_retries=99999) infinite_number = 20 @@ -233,7 +235,7 @@ class AutoFailFalseHttpStream(StubBasicReadHttpStream): raise_on_http_errors = False max_retries = 3 - def get_error_handler(self) -> Optional[ErrorHandler]: + def get_error_handler(self) -> ErrorHandler | None: return HttpStatusErrorHandler(logging.getLogger(), max_retries=3) @@ -345,7 +347,7 @@ def test_form_body(self, mocker, requests_mock): assert response["body"] == self.urlencoded_form_body def test_text_json_body(self, mocker, requests_mock): - """checks a exception if both functions were overridden""" + """Checks a exception if both functions were overridden""" stream = PostHttpStream() mocker.patch.object(stream, 
"request_body_data", return_value=self.data_body) mocker.patch.object(stream, "request_body_json", return_value=self.json_body) @@ -392,7 +394,7 @@ def __init__(self, parent): def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]: return [] - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def path(self, **kwargs) -> str: @@ -452,7 +454,7 @@ class CacheHttpStreamWithSlices(CacheHttpStream): def path(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> str: return f'{stream_slice["path"]}' if stream_slice else "" - def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]: + def stream_slices(self, **kwargs) -> Iterable[Mapping[str, Any] | None]: for path in self.paths: yield {"path": path} @@ -569,7 +571,7 @@ def test_send_raise_on_http_errors_logs(mocker, status_code): ({}, None), ], ) -def test_default_parse_response_error_message(api_response: dict, expected_message: Optional[str]): +def test_default_parse_response_error_message(api_response: dict, expected_message: str | None): stream = StubBasicReadHttpStream() response = MagicMock() response.json.return_value = api_response @@ -753,7 +755,7 @@ class StubParentHttpStream(HttpStream, CheckpointMixin): counter = 0 - def __init__(self, records: List[Mapping[str, Any]]): + def __init__(self, records: list[Mapping[str, Any]]): super().__init__() self._records = records self._state: MutableMapping[str, Any] = {} @@ -765,13 +767,13 @@ def url_base(self) -> str: def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return "/stub" - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return {"__ab_full_refresh_sync_complete": True} def _read_single_page( @@ -781,12 +783,12 @@ def _read_single_page( requests.PreparedRequest, requests.Response, Mapping[str, Any], - Optional[Mapping[str, Any]], + Mapping[str, Any] | None, ], Iterable[StreamData], ], - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: yield from self._records @@ -797,8 +799,8 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: return [] @@ -811,7 +813,7 @@ class StubParentResumableFullRefreshStream(HttpStream, CheckpointMixin): counter = 0 - def __init__(self, record_pages: List[List[Mapping[str, Any]]]): + def __init__(self, record_pages: list[list[Mapping[str, Any]]]): super().__init__() self._record_pages = record_pages self._state: MutableMapping[str, Any] = {} @@ -823,21 +825,21 @@ def url_base(self) -> str: def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - 
next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return "/stub" - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return {"__ab_full_refresh_sync_complete": True} def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: page_number = self.state.get("page") or 1 yield from self._record_pages[page_number - 1] @@ -852,8 +854,8 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: return [] @@ -871,13 +873,13 @@ def url_base(self) -> str: def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return "/stub" - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def _read_pages( @@ -887,12 +889,12 @@ def _read_pages( requests.PreparedRequest, requests.Response, Mapping[str, Any], - Optional[Mapping[str, Any]], + Mapping[str, Any] | None, ], Iterable[StreamData], ], - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: return [ {"id": "abc", "parent": stream_slice.get("id")}, @@ -904,8 +906,8 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: return [] @@ -996,7 +998,7 @@ def __init__(self, deduplicate_query_params: bool = False, pages: int = 5, **kwa self._deduplicate_query_params = deduplicate_query_params self._pages = pages - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: current_page = self.cursor.get_stream_state().get("page", 1) if current_page < self._pages: current_page += 1 @@ -1021,15 +1023,14 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: yield from [{}] def 
test_resumable_full_refresh_read_from_start(mocker): - """ - Validates the default behavior of a stream that supports resumable full refresh by using read_records() which gets one + """Validates the default behavior of a stream that supports resumable full refresh by using read_records() which gets one page per invocation and emits state afterward. parses over """ @@ -1039,7 +1040,7 @@ def test_resumable_full_refresh_read_from_start(mocker): mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) # Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect - mocker.patch.object(stream, "_read_single_page", wraps=getattr(stream, "_read_single_page")) + mocker.patch.object(stream, "_read_single_page", wraps=stream._read_single_page) methods = ["request_params", "request_headers", "request_body_json"] for method in methods: mocker.patch.object(stream, method, wraps=getattr(stream, method)) @@ -1071,7 +1072,7 @@ def test_resumable_full_refresh_read_from_start(mocker): next_stream_slice = checkpoint_reader.next() i += 1 - assert getattr(stream, "_read_single_page").call_count == 5 + assert stream._read_single_page.call_count == 5 # Since we have 5 pages, and we don't pass in the first page, we expect 4 tokens starting at {"page":2}, {"page":3}, etc... expected_next_page_tokens = expected_checkpoints[:4] @@ -1092,8 +1093,7 @@ def test_resumable_full_refresh_read_from_start(mocker): def test_resumable_full_refresh_read_from_state(mocker): - """ - Validates the default behavior of a stream that supports resumable full refresh with an incoming state by using + """Validates the default behavior of a stream that supports resumable full refresh with an incoming state by using read_records() which gets one page per invocation and emits state afterward. parses over """ @@ -1103,7 +1103,7 @@ def test_resumable_full_refresh_read_from_state(mocker): mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) # Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect - mocker.patch.object(stream, "_read_single_page", wraps=getattr(stream, "_read_single_page")) + mocker.patch.object(stream, "_read_single_page", wraps=stream._read_single_page) methods = ["request_params", "request_headers", "request_body_json"] for method in methods: mocker.patch.object(stream, method, wraps=getattr(stream, method)) @@ -1129,7 +1129,7 @@ def test_resumable_full_refresh_read_from_state(mocker): next_stream_slice = checkpoint_reader.next() i += 1 - assert getattr(stream, "_read_single_page").call_count == 3 + assert stream._read_single_page.call_count == 3 # Since we start at page 3, we expect 3 tokens starting at {"page":3}, {"page":4}, etc... 
expected_next_page_tokens = [{"page": 3}, {"page": 4}, {"page": 5}] @@ -1146,8 +1146,7 @@ def test_resumable_full_refresh_read_from_state(mocker): def test_resumable_full_refresh_legacy_stream_slice(mocker): - """ - Validates the default behavior of a stream that supports resumable full refresh where incoming stream slices use the + """Validates the default behavior of a stream that supports resumable full refresh where incoming stream slices use the legacy Mapping format """ pages = 5 @@ -1156,7 +1155,7 @@ def test_resumable_full_refresh_legacy_stream_slice(mocker): mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) # Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect - mocker.patch.object(stream, "_read_single_page", wraps=getattr(stream, "_read_single_page")) + mocker.patch.object(stream, "_read_single_page", wraps=stream._read_single_page) methods = ["request_params", "request_headers", "request_body_json"] for method in methods: mocker.patch.object(stream, method, wraps=getattr(stream, method)) @@ -1187,7 +1186,7 @@ def test_resumable_full_refresh_legacy_stream_slice(mocker): next_stream_slice = checkpoint_reader.next() i += 1 - assert getattr(stream, "_read_single_page").call_count == 4 + assert stream._read_single_page.call_count == 4 # Since we start at page 3, we expect 3 tokens starting at {"page":3}, {"page":4}, etc... expected_next_page_tokens = [{"page": 2}, {"page": 3}, {"page": 4}, {"page": 5}] @@ -1211,7 +1210,7 @@ class StubSubstreamResumableFullRefreshStream(HttpSubStream, CheckpointMixin): def __init__( self, parent: HttpStream, - partition_id_to_child_records: Mapping[str, List[Mapping[str, Any]]], + partition_id_to_child_records: Mapping[str, list[Mapping[str, Any]]], ): super().__init__(parent=parent) self._partition_id_to_child_records = partition_id_to_child_records @@ -1224,13 +1223,13 @@ def url_base(self) -> str: def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return f"/parents/{stream_slice.get('parent_id')}/children" - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None # def read_records( @@ -1250,10 +1249,10 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, def _fetch_next_page( self, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Tuple[requests.PreparedRequest, requests.Response]: + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> tuple[requests.PreparedRequest, requests.Response]: return requests.PreparedRequest(), requests.Response() def parse_response( @@ -1261,8 +1260,8 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> 
Iterable[Mapping[str, Any]]: partition_id = stream_slice.get("parent").get("parent_id") if partition_id in self._partition_id_to_child_records: @@ -1275,12 +1274,10 @@ def get_json_schema(self) -> Mapping[str, Any]: def test_substream_resumable_full_refresh_read_from_start(mocker): - """ - Validates the default behavior of a stream that supports resumable full refresh by using read_records() which gets one + """Validates the default behavior of a stream that supports resumable full refresh by using read_records() which gets one page per invocation and emits state afterward. parses over """ - parent_records = [ {"parent_id": "100", "name": "christopher_nolan"}, {"parent_id": "101", "name": "celine_song"}, @@ -1312,7 +1309,7 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) # Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect - mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages")) + mocker.patch.object(stream, "_read_pages", wraps=stream._read_pages) checkpoint_reader = stream._get_checkpoint_reader( cursor_field=[], @@ -1373,7 +1370,7 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): next_stream_slice = checkpoint_reader.next() i += 1 - assert getattr(stream, "_read_pages").call_count == 3 + assert stream._read_pages.call_count == 3 expected = [ {"film": "interstellar", "id": "a200", "parent_id": "100"}, @@ -1390,12 +1387,10 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): def test_substream_resumable_full_refresh_read_from_state(mocker): - """ - Validates the default behavior of a stream that supports resumable full refresh by using read_records() which gets one + """Validates the default behavior of a stream that supports resumable full refresh by using read_records() which gets one page per invocation and emits state afterward. 
parses over """ - parent_records = [ {"parent_id": "100", "name": "christopher_nolan"}, {"parent_id": "101", "name": "celine_song"}, @@ -1421,7 +1416,7 @@ def test_substream_resumable_full_refresh_read_from_state(mocker): mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) # Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect - mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages")) + mocker.patch.object(stream, "_read_pages", wraps=stream._read_pages) checkpoint_reader = stream._get_checkpoint_reader( cursor_field=[], @@ -1465,7 +1460,7 @@ def test_substream_resumable_full_refresh_read_from_state(mocker): next_stream_slice = checkpoint_reader.next() i += 1 - assert getattr(stream, "_read_pages").call_count == 1 + assert stream._read_pages.call_count == 1 expected = [ {"film": "past_lives", "id": "b200", "parent_id": "101"}, @@ -1479,7 +1474,7 @@ class StubWithCursorFields(StubBasicReadHttpStream): def __init__( self, has_multiple_slices: bool, - set_cursor_field: List[str], + set_cursor_field: list[str], deduplicate_query_params: bool = False, **kwargs, ): @@ -1488,7 +1483,7 @@ def __init__( super().__init__() @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return self._cursor_field diff --git a/unit_tests/sources/streams/http/test_http_client.py b/unit_tests/sources/streams/http/test_http_client.py index 4ef9e968..9f5c1c0f 100644 --- a/unit_tests/sources/streams/http/test_http_client.py +++ b/unit_tests/sources/streams/http/test_http_client.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations import logging from datetime import timedelta @@ -6,6 +7,8 @@ import pytest import requests +from requests_cache import CachedRequest + from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.call_rate import CachedLimiterSession, LimiterSession from airbyte_cdk.sources.streams.http import HttpClient @@ -22,7 +25,6 @@ ) from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from requests_cache import CachedRequest def test_http_client(): @@ -343,12 +345,11 @@ def test_send_request_given_retry_response_action_retries_and_returns_valid_resp def update_response(*args, **kwargs): if http_client._session.send.call_count == call_count: return valid_response - else: - retry_response = MagicMock(spec=requests.Response) - retry_response.ok = False - retry_response.status_code = 408 - retry_response.headers = {} - return retry_response + retry_response = MagicMock(spec=requests.Response) + retry_response.ok = False + retry_response.status_code = 408 + retry_response.headers = {} + return retry_response mocked_session.send.side_effect = update_response @@ -449,8 +450,7 @@ def test_send_request_given_request_exception_and_retry_response_action_retries_ def update_response(*args, **kwargs): if mocked_session.send.call_count == call_count: return valid_response - else: - raise requests.RequestException() + raise requests.RequestException() mocked_session.send.side_effect = update_response diff --git a/unit_tests/sources/streams/test_call_rate.py b/unit_tests/sources/streams/test_call_rate.py index 518951cd..a09dc73a 100644 --- a/unit_tests/sources/streams/test_call_rate.py +++ b/unit_tests/sources/streams/test_call_rate.py @@ -1,14 +1,19 
@@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations + import os import tempfile import time +from collections.abc import Iterable, Mapping from datetime import datetime, timedelta -from typing import Any, Iterable, Mapping, Optional +from typing import Any import pytest import requests +from requests import Request + from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.streams.call_rate import ( APIBudget, @@ -22,14 +27,13 @@ from airbyte_cdk.sources.streams.http import HttpStream from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH -from requests import Request class StubDummyHttpStream(HttpStream): url_base = "https://test_base_url.com" primary_key = "some_key" - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return {"next_page_token": True} # endless pages def path(self, **kwargs) -> str: @@ -243,7 +247,7 @@ def test_update_available_calls(self, mocker): class TestMovingWindowCallRatePolicy: def test_no_rates(self): - """should raise a ValueError when no rates provided""" + """Should raise a ValueError when no rates provided""" with pytest.raises(ValueError, match="The list of rates can not be empty"): MovingWindowCallRatePolicy(rates=[], matchers=[]) diff --git a/unit_tests/sources/streams/test_stream_read.py b/unit_tests/sources/streams/test_stream_read.py index f079fe99..8af082a5 100644 --- a/unit_tests/sources/streams/test_stream_read.py +++ b/unit_tests/sources/streams/test_stream_read.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging +from collections.abc import Iterable, Mapping, MutableMapping from copy import deepcopy -from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any from unittest.mock import Mock import pytest + from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, @@ -38,6 +41,7 @@ from airbyte_cdk.sources.utils.schema_helpers import InternalConfig from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger + _A_CURSOR_FIELD = ["NESTED", "CURSOR"] _DEFAULT_INTERNAL_CONFIG = InternalConfig() _STREAM_NAME = "STREAM" @@ -47,32 +51,32 @@ class _MockStream(Stream): def __init__( self, - slice_to_records: Mapping[str, List[Mapping[str, Any]]], - json_schema: Dict[str, Any] = None, + slice_to_records: Mapping[str, list[Mapping[str, Any]]], + json_schema: dict[str, Any] = None, ): self._slice_to_records = slice_to_records self._mocked_json_schema = json_schema or {} @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return None def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: for partition in self._slice_to_records.keys(): yield {"partition_key": partition} def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] 
| None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: yield from self._slice_to_records[stream_slice["partition_key"]] @@ -93,15 +97,15 @@ def state(self, value: MutableMapping[str, Any]) -> None: self._state = value @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return ["created_at"] def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: cursor = self.cursor_field[0] for record in self._slice_to_records[stream_slice["partition_key"]]: @@ -156,7 +160,7 @@ def _concurrent_stream( slice_logger, logger, message_repository, - cursor: Optional[Cursor] = None, + cursor: Cursor | None = None, ): stream = _stream(slice_to_partition_mapping, slice_logger, logger, message_repository) cursor = cursor or FinalStateCursor( @@ -623,8 +627,7 @@ def test_configured_json_schema(): def test_configured_json_schema_with_invalid_properties(): - """ - Configured Schemas can have very old fields, so we need to housekeeping ourselves. + """Configured Schemas can have very old fields, so we need to housekeeping ourselves. The purpose of this test in ensure that correct cleanup occurs when configured catalog schema is compared with current stream schema. """ old_user_insights = "old_user_insights" diff --git a/unit_tests/sources/streams/test_streams_core.py b/unit_tests/sources/streams/test_streams_core.py index 9f21ebab..866dd19c 100644 --- a/unit_tests/sources/streams/test_streams_core.py +++ b/unit_tests/sources/streams/test_streams_core.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional +from collections.abc import Iterable, Mapping, MutableMapping +from typing import Any from unittest import mock import pytest import requests + from airbyte_cdk.models import AirbyteStream, SyncMode from airbyte_cdk.sources.streams import CheckpointMixin, Stream from airbyte_cdk.sources.streams.checkpoint import ( @@ -22,18 +25,17 @@ from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream from airbyte_cdk.sources.types import StreamSlice + logger = logging.getLogger("airbyte") class StreamStubFullRefresh(Stream): - """ - Stub full refresh class to assist with testing. - """ + """Stub full refresh class to assist with testing.""" def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -43,16 +45,14 @@ def read_records( class StreamStubIncremental(Stream, CheckpointMixin): - """ - Stub full incremental class to assist with testing. 
- """ + """Stub full incremental class to assist with testing.""" _state = {} def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -72,16 +72,14 @@ def state(self, value: MutableMapping[str, Any]) -> None: class StreamStubResumableFullRefresh(Stream, CheckpointMixin): - """ - Stub full incremental class to assist with testing. - """ + """Stub full incremental class to assist with testing.""" _state = {} def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -99,16 +97,14 @@ def state(self, value: MutableMapping[str, Any]) -> None: class StreamStubLegacyStateInterface(Stream): - """ - Stub full incremental class to assist with testing. - """ + """Stub full incremental class to assist with testing.""" _state = {} def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -125,14 +121,12 @@ def get_updated_state( class StreamStubIncrementalEmptyNamespace(Stream): - """ - Stub full incremental class, with empty namespace, to assist with testing. - """ + """Stub full incremental class, with empty namespace, to assist with testing.""" def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: @@ -144,9 +138,7 @@ def read_records( class HttpSubStreamStubFullRefreshLegacySlices(HttpSubStream): - """ - Stub substream full refresh class to assist with testing. 
- """ + """Stub substream full refresh class to assist with testing.""" primary_key = "primary_key" @@ -154,15 +146,15 @@ class HttpSubStreamStubFullRefreshLegacySlices(HttpSubStream): def url_base(self) -> str: return "https://airbyte.io/api/v1" - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: pass def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return "/stub" @@ -171,14 +163,14 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: return [] class CursorBasedStreamStubFullRefresh(StreamStubFullRefresh): - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: return ResumableFullRefreshCursor() @@ -187,16 +179,14 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: yield from [{}] class MultipleSlicesStreamStub(HttpStream): - """ - Stub full refresh class that returns multiple StreamSlice instances to assist with testing. 
- """ + """Stub full refresh class that returns multiple StreamSlice instances to assist with testing.""" primary_key = "primary_key" @@ -208,23 +198,23 @@ def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: yield from [ StreamSlice(partition={"parent_id": "korra"}, cursor_slice={}), StreamSlice(partition={"parent_id": "asami"}, cursor_slice={}), ] - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: pass def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return "/stub" @@ -233,8 +223,8 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: return [] @@ -246,21 +236,21 @@ class ParentHttpStreamStub(HttpStream): def read_records( self, sync_mode: SyncMode, - cursor_field: List[str] = None, + cursor_field: list[str] = None, stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: return [{"id": 400, "name": "a_parent_record"}] - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: return None def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: return "/parent" @@ -269,15 +259,14 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: return [] def test_as_airbyte_stream_full_refresh(mocker): - """ - Should return an full refresh AirbyteStream with information matching the + """Should return an full refresh AirbyteStream with information matching the provided Stream interface. """ test_stream = StreamStubFullRefresh() @@ -295,8 +284,7 @@ def test_as_airbyte_stream_full_refresh(mocker): def test_as_airbyte_stream_incremental(mocker): - """ - Should return an incremental refresh AirbyteStream with information matching + """Should return an incremental refresh AirbyteStream with information matching the provided Stream interface. """ test_stream = StreamStubIncremental() @@ -318,9 +306,7 @@ def test_as_airbyte_stream_incremental(mocker): def test_supports_incremental_cursor_set(): - """ - Should return true if cursor is set. 
- """ + """Should return true if cursor is set.""" test_stream = StreamStubIncremental() test_stream.cursor_field = "test_cursor" @@ -328,27 +314,21 @@ def test_supports_incremental_cursor_set(): def test_supports_incremental_cursor_not_set(): - """ - Should return false if cursor is not. - """ + """Should return false if cursor is not.""" test_stream = StreamStubFullRefresh() assert not test_stream.supports_incremental def test_namespace_set(): - """ - Should allow namespace property to be set. - """ + """Should allow namespace property to be set.""" test_stream = StreamStubIncremental() assert test_stream.namespace == "test_namespace" def test_namespace_set_to_empty_string(mocker): - """ - Should not set namespace property if equal to empty string. - """ + """Should not set namespace property if equal to empty string.""" test_stream = StreamStubIncremental() mocker.patch.object(StreamStubIncremental, "get_json_schema", return_value={}) @@ -370,9 +350,7 @@ def test_namespace_set_to_empty_string(mocker): def test_namespace_not_set(): - """ - Should be equal to unset value of None. - """ + """Should be equal to unset value of None.""" test_stream = StreamStubFullRefresh() assert test_stream.namespace is None @@ -387,10 +365,7 @@ def test_namespace_not_set(): ], ) def test_wrapped_primary_key_various_argument(test_input, expected): - """ - Should always wrap primary key into list of lists. - """ - + """Should always wrap primary key into list of lists.""" wrapped = Stream._wrapped_primary_key(test_input) assert wrapped == expected @@ -464,8 +439,7 @@ def test_get_checkpoint_reader(stream: Stream, stream_state, expected_checkpoint def test_checkpoint_reader_with_no_partitions(): - """ - Tests the edge case where an incremental stream might not generate any partitions, but should still attempt at least + """Tests the edge case where an incremental stream might not generate any partitions, but should still attempt at least one iteration of calling read_records() """ stream = StreamStubIncremental() diff --git a/unit_tests/sources/streams/utils/test_stream_helper.py b/unit_tests/sources/streams/utils/test_stream_helper.py index 39b642cb..4cfb5198 100644 --- a/unit_tests/sources/streams/utils/test_stream_helper.py +++ b/unit_tests/sources/streams/utils/test_stream_helper.py @@ -1,8 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest + from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy diff --git a/unit_tests/sources/test_abstract_source.py b/unit_tests/sources/test_abstract_source.py index 2cc0db54..a27def7d 100644 --- a/unit_tests/sources/test_abstract_source.py +++ b/unit_tests/sources/test_abstract_source.py @@ -1,25 +1,20 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations import copy import datetime import logging +from collections.abc import Callable, Iterable, Mapping, MutableMapping from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Mapping, - MutableMapping, - Optional, - Tuple, - Union, ) from unittest.mock import Mock import pytest +from pytest import fixture + from airbyte_cdk.models import ( AirbyteCatalog, AirbyteConnectionStatus, @@ -44,8 +39,8 @@ StreamDescriptor, SyncMode, TraceType, + Type, ) -from airbyte_cdk.models import Type from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources import AbstractSource from airbyte_cdk.sources.message import MessageRepository @@ -53,7 +48,7 @@ from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message from airbyte_cdk.utils.airbyte_secrets_utils import update_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from pytest import fixture + logger = logging.getLogger("airbyte") @@ -61,8 +56,8 @@ class MockSource(AbstractSource): def __init__( self, - check_lambda: Callable[[], Tuple[bool, Optional[Any]]] = None, - streams: List[Stream] = None, + check_lambda: Callable[[], tuple[bool, Any | None]] = None, + streams: list[Stream] = None, message_repository: MessageRepository = None, exception_on_missing_stream: bool = True, stop_sync_on_stream_failure: bool = False, @@ -75,12 +70,12 @@ def __init__( def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: if self.check_lambda: return self.check_lambda() return False, "Missing callable." - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: if not self._streams: raise Exception("Stream is not set") return self._streams @@ -176,8 +171,8 @@ def test_raising_check(mocker): class MockStream(Stream): def __init__( self, - inputs_and_mocked_outputs: List[ - Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]] + inputs_and_mocked_outputs: list[ + tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]] ] = None, name: str = None, ): @@ -201,11 +196,11 @@ def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore ) @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return "pk" @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return ["updated_at"] @@ -214,7 +209,7 @@ class MockStreamWithCursor(MockStream): def __init__( self, - inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], + inputs_and_mocked_outputs: list[tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], name: str, ): super().__init__(inputs_and_mocked_outputs, name) @@ -223,7 +218,7 @@ def __init__( class MockStreamWithState(MockStreamWithCursor): def __init__( self, - inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], + inputs_and_mocked_outputs: list[tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], name: str, state=None, ): @@ -242,7 +237,7 @@ def state(self, value): class MockStreamEmittingAirbyteMessages(MockStreamWithState): def __init__( self, - inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, + inputs_and_mocked_outputs: list[tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, name: str = None, state=None, ): @@ 
-255,7 +250,7 @@ def name(self): return self._name @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return "pk" @property @@ -270,7 +265,7 @@ def state(self, value: MutableMapping[str, Any]): class MockResumableFullRefreshStream(Stream): def __init__( self, - inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Mapping[str, Any]]] = None, + inputs_and_mocked_outputs: list[tuple[Mapping[str, Any], Mapping[str, Any]]] = None, name: str = None, ): self._inputs_and_mocked_outputs = inputs_and_mocked_outputs @@ -303,7 +298,7 @@ def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore yield from output @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return "id" @property @@ -454,14 +449,14 @@ def test_read_stream_with_error_gets_display_message(mocker): GLOBAL_EMITTED_AT = 1 -def _as_record(stream: str, data: Dict[str, Any]) -> AirbyteMessage: +def _as_record(stream: str, data: dict[str, Any]) -> AirbyteMessage: return AirbyteMessage( type=Type.RECORD, record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=GLOBAL_EMITTED_AT), ) -def _as_records(stream: str, data: List[Dict[str, Any]]) -> List[AirbyteMessage]: +def _as_records(stream: str, data: list[dict[str, Any]]) -> list[AirbyteMessage]: return [_as_record(stream, datum) for datum in data] @@ -479,7 +474,7 @@ def _as_stream_status(stream: str, status: AirbyteStreamStatus) -> AirbyteMessag return AirbyteMessage(type=MessageType.TRACE, trace=trace_message) -def _as_state(stream_name: str = "", per_stream_state: Dict[str, Any] = None): +def _as_state(stream_name: str = "", per_stream_state: dict[str, Any] = None): return AirbyteMessage( type=Type.STATE, state=AirbyteStateMessage( @@ -495,9 +490,9 @@ def _as_state(stream_name: str = "", per_stream_state: Dict[str, Any] = None): def _as_error_trace( stream: str, error_message: str, - internal_message: Optional[str], - failure_type: Optional[FailureType], - stack_trace: Optional[str], + internal_message: str | None, + failure_type: FailureType | None, + stack_trace: str | None, ) -> AirbyteMessage: trace_message = AirbyteTraceMessage( emitted_at=datetime.datetime.now().timestamp() * 1000.0, @@ -522,7 +517,7 @@ def _configured_stream(stream: Stream, sync_mode: SyncMode): ) -def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]: +def _fix_emitted_at(messages: list[AirbyteMessage]) -> list[AirbyteMessage]: for msg in messages: if msg.type == Type.RECORD and msg.record: msg.record.emitted_at = GLOBAL_EMITTED_AT @@ -668,12 +663,16 @@ def test_read_full_refresh_with_slices_sends_slice_messages(mocker, slices): messages = src.read(debug_logger, {}, catalog) - assert 2 == len( - list( - filter( - lambda message: message.log and message.log.message.startswith("slice:"), messages + assert ( + len( + list( + filter( + lambda message: message.log and message.log.message.startswith("slice:"), + messages, + ) ) ) + == 2 ) @@ -703,12 +702,16 @@ def test_read_incremental_with_slices_sends_slice_messages(mocker): messages = src.read(debug_logger, {}, catalog) - assert 2 == len( - list( - filter( - lambda message: message.log and message.log.message.startswith("slice:"), messages + assert ( + len( + list( + filter( + lambda message: message.log and message.log.message.startswith("slice:"), + messages, + ) ) ) + == 2 ) @@ -1012,9 +1015,8 @@ def 
test_with_slices(self, mocker): ], ) def test_no_slices(self, mocker, slices): - """ - Tests that an incremental read returns at least one state messages even if no records were read: - 1. outputs a state message after reading the entire stream + """Tests that an incremental read returns at least one state messages even if no records were read: + 1. outputs a state message after reading the entire stream """ state = {"cursor": "value"} input_state = [ @@ -1100,11 +1102,10 @@ def test_no_slices(self, mocker, slices): assert messages == expected def test_with_slices_and_interval(self, mocker): - """ - Tests that an incremental read which uses slices and a checkpoint interval: - 1. outputs all records - 2. outputs a state message every N records (N=checkpoint_interval) - 3. outputs a state message after reading the entire slice + """Tests that an incremental read which uses slices and a checkpoint interval: + 1. outputs all records + 2. outputs a state message every N records (N=checkpoint_interval) + 3. outputs a state message after reading the entire slice """ input_state = [] slices = [{"1": "1"}, {"2": "2"}] @@ -1199,13 +1200,11 @@ def test_with_slices_and_interval(self, mocker): assert messages == expected def test_emit_non_records(self, mocker): + """Tests that an incremental read which uses slices and a checkpoint interval: + 1. outputs all records + 2. outputs a state message every N records (N=checkpoint_interval) + 3. outputs a state message after reading the entire slice """ - Tests that an incremental read which uses slices and a checkpoint interval: - 1. outputs all records - 2. outputs a state message every N records (N=checkpoint_interval) - 3. outputs a state message after reading the entire slice - """ - input_state = [] slices = [{"1": "1"}, {"2": "2"}] stream_output = [ @@ -1319,8 +1318,7 @@ def test_emit_non_records(self, mocker): assert messages == expected def test_without_state_attribute_for_stream_with_desc_records(self, mocker): - """ - This test will check that the state resolved by get_updated_state is used and returned in the state message. + """This test will check that the state resolved by get_updated_state is used and returned in the state message. In this scenario records are returned in descending order, but we keep the "highest" cursor in the state. """ stream_cursor = MockStreamWithCursor.cursor_field @@ -1611,8 +1609,7 @@ def test_resumable_full_refresh_partial_failure(self, mocker): assert exc.value.failure_type == FailureType.config_error def test_resumable_full_refresh_skip_prior_successful_streams(self, mocker): - """ - Tests that running a resumable full refresh sync from the second attempt where one stream was successful + """Tests that running a resumable full refresh sync from the second attempt where one stream was successful and should not be synced. The other should sync beginning at the partial state passed in. """ responses = [ @@ -1773,8 +1770,7 @@ def test_resumable_full_refresh_skip_prior_successful_streams(self, mocker): def test_continue_sync_with_failed_streams( mocker, exception_to_raise, expected_error_message, expected_internal_message ): - """ - Tests that running a sync for a connector with multiple streams will continue syncing when one stream fails + """Tests that running a sync for a connector with multiple streams will continue syncing when one stream fails with an error. This source does not override the default behavior defined in the AbstractSource class. 
""" stream_output = [{"k1": "v1"}, {"k2": "v2"}] @@ -1827,8 +1823,7 @@ def test_continue_sync_with_failed_streams( def test_continue_sync_source_override_false(mocker): - """ - Tests that running a sync for a connector explicitly overriding the default AbstractSource.stop_sync_on_stream_failure + """Tests that running a sync for a connector explicitly overriding the default AbstractSource.stop_sync_on_stream_failure property to be False which will continue syncing stream even if one encountered an exception. """ update_secrets(["API_KEY_VALUE"]) @@ -1885,9 +1880,7 @@ def test_continue_sync_source_override_false(mocker): def test_sync_error_trace_messages_obfuscate_secrets(mocker): - """ - Tests that exceptions emitted as trace messages by a source have secrets properly sanitized - """ + """Tests that exceptions emitted as trace messages by a source have secrets properly sanitized""" update_secrets(["API_KEY_VALUE"]) stream_output = [{"k1": "v1"}, {"k2": "v2"}] @@ -1944,8 +1937,7 @@ def test_sync_error_trace_messages_obfuscate_secrets(mocker): def test_continue_sync_with_failed_streams_with_override_false(mocker): - """ - Tests that running a sync for a connector with multiple streams and stop_sync_on_stream_failure enabled stops + """Tests that running a sync for a connector with multiple streams and stop_sync_on_stream_failure enabled stops the sync when one stream fails with an error. """ stream_output = [{"k1": "v1"}, {"k2": "v2"}] @@ -2014,9 +2006,7 @@ def test_continue_sync_with_failed_streams_with_override_false(mocker): # TODO: Replace call of this function to fixture in the tests def _remove_stack_trace(message: AirbyteMessage) -> AirbyteMessage: - """ - Helper method that removes the stack trace from Airbyte trace messages to make asserting against expected records easier - """ + """Helper method that removes the stack trace from Airbyte trace messages to make asserting against expected records easier""" if message.trace and message.trace.error and message.trace.error.stack_trace: message.trace.error.stack_trace = None return message @@ -2025,9 +2015,7 @@ def _remove_stack_trace(message: AirbyteMessage) -> AirbyteMessage: def test_read_nonexistent_stream_emit_incomplete_stream_status( mocker, remove_stack_trace, as_stream_status ): - """ - Tests that attempting to sync a stream which the source does not return from the `streams` method emit incomplete stream status - """ + """Tests that attempting to sync a stream which the source does not return from the `streams` method emit incomplete stream status""" s1 = MockStream(name="s1") s2 = MockStream(name="this_stream_doesnt_exist_in_the_source") diff --git a/unit_tests/sources/test_config.py b/unit_tests/sources/test_config.py index 933bdc8f..24e18d13 100644 --- a/unit_tests/sources/test_config.py +++ b/unit_tests/sources/test_config.py @@ -1,11 +1,11 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from __future__ import annotations -from typing import List, Union +from pydantic.v1 import BaseModel, Field from airbyte_cdk.sources.config import BaseConfig -from pydantic.v1 import BaseModel, Field class InnerClass(BaseModel): @@ -23,15 +23,15 @@ class Choice1(BaseModel): class Choice2(BaseModel): selected_strategy = Field("option2", const=True) - sequence: List[str] + sequence: list[str] class SomeSourceConfig(BaseConfig): class Config: title = "Some Source" - items: List[InnerClass] - choice: Union[Choice1, Choice2] + items: list[InnerClass] + choice: Choice1 | Choice2 class TestBaseConfig: diff --git a/unit_tests/sources/test_connector_state_manager.py b/unit_tests/sources/test_connector_state_manager.py index 9e53f2e6..585d7ffc 100644 --- a/unit_tests/sources/test_connector_state_manager.py +++ b/unit_tests/sources/test_connector_state_manager.py @@ -1,11 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import List import pytest + from airbyte_cdk.models import ( AirbyteMessage, AirbyteStateBlob, @@ -158,7 +159,7 @@ ), ) def test_initialize_state_manager(input_stream_state, expected_stream_state, expected_error): - if isinstance(input_stream_state, List): + if isinstance(input_stream_state, list): input_stream_state = [ AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_stream_state) ] diff --git a/unit_tests/sources/test_http_logger.py b/unit_tests/sources/test_http_logger.py index 29f73e69..115b419e 100644 --- a/unit_tests/sources/test_http_logger.py +++ b/unit_tests/sources/test_http_logger.py @@ -1,11 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import pytest import requests + from airbyte_cdk.sources.http_logger import format_http_message + A_TITLE = "a title" A_DESCRIPTION = "a description" A_STREAM_NAME = "a stream name" @@ -21,19 +24,19 @@ def __init__(self): self._request = ANY_REQUEST self._status_code = 100 - def body_content(self, body_content: bytes) -> "ResponseBuilder": + def body_content(self, body_content: bytes) -> ResponseBuilder: self._body_content = body_content return self - def headers(self, headers: dict) -> "ResponseBuilder": + def headers(self, headers: dict) -> ResponseBuilder: self._headers = headers return self - def request(self, request: requests.PreparedRequest) -> "ResponseBuilder": + def request(self, request: requests.PreparedRequest) -> ResponseBuilder: self._request = request return self - def status_code(self, status_code: int) -> "ResponseBuilder": + def status_code(self, status_code: int) -> ResponseBuilder: self._status_code = status_code return self diff --git a/unit_tests/sources/test_integration_source.py b/unit_tests/sources/test_integration_source.py index 1f86d1e7..09a238f2 100644 --- a/unit_tests/sources/test_integration_source.py +++ b/unit_tests/sources/test_integration_source.py @@ -1,17 +1,18 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
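The ResponseBuilder hunk above drops the quoted return annotations (-> "ResponseBuilder" becomes -> ResponseBuilder), which is safe once annotation evaluation is postponed. A minimal sketch of the same fluent-builder shape with an invented QueryBuilder class:

from __future__ import annotations


class QueryBuilder:
    """Tiny fluent builder; each setter returns self for chaining."""

    def __init__(self) -> None:
        self._params: dict[str, str] = {}

    # No quotes needed around QueryBuilder: the annotation is stored as a
    # string and only resolved on demand (e.g. by typing.get_type_hints).
    def param(self, key: str, value: str) -> QueryBuilder:
        self._params[key] = value
        return self

    def build(self) -> dict[str, str]:
        return dict(self._params)


assert QueryBuilder().param("page", "2").param("limit", "50").build() == {"page": "2", "limit": "50"}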
# +from __future__ import annotations import json import os -from typing import Any, List, Mapping +from collections.abc import Mapping +from typing import Any from unittest import mock from unittest.mock import patch import pytest import requests -from airbyte_cdk.entrypoint import launch -from airbyte_cdk.utils import AirbyteTracedException + from unit_tests.sources.fixtures.source_test_fixture import ( HttpTestStream, SourceFixtureOauthAuthenticator, @@ -19,6 +20,9 @@ fixture_mock_send, ) +from airbyte_cdk.entrypoint import launch +from airbyte_cdk.utils import AirbyteTracedException + @pytest.mark.parametrize( "deployment_mode, url_base, expected_records, expected_error", @@ -144,12 +148,10 @@ def test_external_oauth_request_source( launch(source, args) -def contains_error_trace_message(messages: List[Mapping[str, Any]], expected_error: str) -> bool: +def contains_error_trace_message(messages: list[Mapping[str, Any]], expected_error: str) -> bool: for message in messages: - if message.get("type") != "TRACE": - continue - elif message.get("trace").get("type") != "ERROR": + if message.get("type") != "TRACE" or message.get("trace").get("type") != "ERROR": continue - elif message.get("trace").get("error").get("failure_type") == expected_error: + if message.get("trace").get("error").get("failure_type") == expected_error: return True return False diff --git a/unit_tests/sources/test_source.py b/unit_tests/sources/test_source.py index c47b12a0..79cf394d 100644 --- a/unit_tests/sources/test_source.py +++ b/unit_tests/sources/test_source.py @@ -1,14 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging import tempfile +from collections.abc import Mapping, MutableMapping from contextlib import nullcontext as does_not_raise -from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any import pytest +from orjson import orjson +from serpyco_rs import SchemaValidationError + from airbyte_cdk.models import ( AirbyteGlobalState, AirbyteStateBlob, @@ -26,8 +31,6 @@ from airbyte_cdk.sources.streams.core import Stream from airbyte_cdk.sources.streams.http.http import HttpStream from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer -from orjson import orjson -from serpyco_rs import SchemaValidationError class MockSource(Source): @@ -48,13 +51,13 @@ def discover(self, logger: logging.Logger, config: Mapping[str, Any]): class MockAbstractSource(AbstractSource): - def __init__(self, streams: Optional[List[Stream]] = None): + def __init__(self, streams: list[Stream] | None = None): self._streams = streams - def check_connection(self, *args, **kwargs) -> Tuple[bool, Optional[Any]]: + def check_connection(self, *args, **kwargs) -> tuple[bool, Any | None]: return True, "" - def streams(self, *args, **kwargs) -> List[Stream]: + def streams(self, *args, **kwargs) -> list[Stream]: if self._streams: return self._streams return [] @@ -104,7 +107,7 @@ class MockHttpStream(mocker.MagicMock, HttpStream): _state = {} @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: return ["updated_at"] def get_backoff_strategy(self): diff --git a/unit_tests/sources/test_source_read.py b/unit_tests/sources/test_source_read.py index a4878a8c..7b00a5ec 100644 --- a/unit_tests/sources/test_source_read.py +++ b/unit_tests/sources/test_source_read.py @@ -1,11 +1,19 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
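Several hunks in this patch (the update_response side effects and contains_error_trace_message above) apply the same simplification: guard clauses instead of elif branches after continue or return. A self-contained sketch of that control flow with made-up trace messages:

from __future__ import annotations

from collections.abc import Mapping
from typing import Any


def contains_error(messages: list[Mapping[str, Any]], expected: str) -> bool:
    for message in messages:
        # Guard clauses replace the old if/elif ladder: skip anything that
        # is not an ERROR trace, then test the failure type directly.
        if message.get("type") != "TRACE" or message.get("trace", {}).get("type") != "ERROR":
            continue
        if message["trace"].get("error", {}).get("failure_type") == expected:
            return True
    return False


messages = [
    {"type": "RECORD"},
    {"type": "TRACE", "trace": {"type": "ERROR", "error": {"failure_type": "config_error"}}},
]
assert contains_error(messages, "config_error")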
# +from __future__ import annotations + import logging -from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any from unittest.mock import Mock import freezegun + +from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( + NeverLogSliceLogger, +) + from airbyte_cdk.models import ( AirbyteMessage, AirbyteRecordMessage, @@ -30,13 +38,10 @@ from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.utils import AirbyteTracedException -from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( - NeverLogSliceLogger, -) class _MockStream(Stream): - def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], name: str): + def __init__(self, slice_to_records: Mapping[str, list[Mapping[str, Any]]], name: str): self._slice_to_records = slice_to_records self._name = name @@ -45,25 +50,25 @@ def name(self) -> str: return self._name @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return None def stream_slices( self, *, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + cursor_field: list[str] | None = None, + stream_state: Mapping[str, Any] | None = None, + ) -> Iterable[Mapping[str, Any] | None]: for partition in self._slice_to_records.keys(): yield {"partition": partition} def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: for record_or_exception in self._slice_to_records[stream_slice["partition"]]: if isinstance(record_or_exception, Exception): @@ -80,13 +85,13 @@ class _MockSource(AbstractSource): def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: pass def set_streams(self, streams): self._streams = streams - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: return self._streams @@ -101,13 +106,13 @@ def __init__(self, logger): def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: pass def set_streams(self, streams): self._streams = streams - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: return self._streams diff --git a/unit_tests/sources/utils/test_record_helper.py b/unit_tests/sources/utils/test_record_helper.py index 71f882df..cda6a76f 100644 --- a/unit_tests/sources/utils/test_record_helper.py +++ b/unit_tests/sources/utils/test_record_helper.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
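Editorial sketch of the import layout these files are being normalized to: standard library, then third-party, then first-party airbyte_cdk and local helpers, each group separated by a blank line. Running it assumes the CDK's own test environment (pytest, requests and airbyte_cdk installed).

from __future__ import annotations

# Standard library imports come first.
import json
import logging

# Third-party packages follow after a blank line.
import pytest
import requests

# First-party airbyte_cdk imports (and local test helpers) come last.
from airbyte_cdk.models import SyncMode


print(json.dumps({"sync_mode": str(SyncMode.full_refresh)}))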
# +from __future__ import annotations from unittest.mock import MagicMock import pytest + from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, @@ -18,6 +20,7 @@ from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message + NOW = 1234567 STREAM_NAME = "my_stream" diff --git a/unit_tests/sources/utils/test_schema_helpers.py b/unit_tests/sources/utils/test_schema_helpers.py index 495c728e..b773cee6 100644 --- a/unit_tests/sources/utils/test_schema_helpers.py +++ b/unit_tests/sources/utils/test_schema_helpers.py @@ -1,7 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import json import logging @@ -14,6 +14,9 @@ import jsonref import pytest +from pytest import fixture +from pytest import raises as pytest_raises + from airbyte_cdk.models import ConnectorSpecification, ConnectorSpecificationSerializer, FailureType from airbyte_cdk.sources.utils.schema_helpers import ( InternalConfig, @@ -21,8 +24,7 @@ check_config_against_spec_or_exit, ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from pytest import fixture -from pytest import raises as pytest_raises + logger = logging.getLogger("airbyte") @@ -58,7 +60,7 @@ def spec_object() -> ConnectorSpecification: }, }, } - yield ConnectorSpecificationSerializer.load(spec) + return ConnectorSpecificationSerializer.load(spec) def test_check_config_against_spec_or_exit_does_not_print_schema(capsys, spec_object): diff --git a/unit_tests/sources/utils/test_slice_logger.py b/unit_tests/sources/utils/test_slice_logger.py index 43b54050..0a03bd75 100644 --- a/unit_tests/sources/utils/test_slice_logger.py +++ b/unit_tests/sources/utils/test_slice_logger.py @@ -1,10 +1,12 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import logging import pytest + from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.utils.slice_logger import AlwaysLogSliceLogger, DebugSliceLogger @@ -24,7 +26,7 @@ ), pytest.param( DebugSliceLogger(), - logging.WARN, + logging.WARNING, False, id="debug_logger_should_not_log_if_level_is_warn", ), @@ -60,7 +62,7 @@ ), pytest.param( AlwaysLogSliceLogger(), - logging.WARN, + logging.WARNING, True, id="always_log_logger_should_log_if_level_is_warn", ), diff --git a/unit_tests/sources/utils/test_transform.py b/unit_tests/sources/utils/test_transform.py index 5d7aa1a6..b1356ca4 100644 --- a/unit_tests/sources/utils/test_transform.py +++ b/unit_tests/sources/utils/test_transform.py @@ -1,12 +1,15 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import pytest + from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer + SIMPLE_SCHEMA = {"type": "object", "properties": {"value": {"type": "string"}}} COMPLEX_SCHEMA = { "type": "object", diff --git a/unit_tests/test/mock_http/test_matcher.py b/unit_tests/test/mock_http/test_matcher.py index a1018a01..9d29a345 100644 --- a/unit_tests/test/mock_http/test_matcher.py +++ b/unit_tests/test/mock_http/test_matcher.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
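The spec_object fixture above, like simple_config later in the patch, switches from yield to a plain return because it has no teardown to run. A short sketch of both fixture shapes, assuming pytest; the fixture and test names are invented.

from __future__ import annotations

import pytest


@pytest.fixture
def static_config() -> dict[str, str]:
    # Nothing to tear down, so a plain return is enough.
    return {"api_key": "dummy"}


@pytest.fixture
def open_resource():
    # yield-style fixtures remain the right tool when cleanup must run
    # after the test body finishes.
    resource = {"open": True}
    yield resource
    resource["open"] = False


def test_static_config(static_config):
    assert static_config["api_key"] == "dummy"


def test_open_resource(open_resource):
    assert open_resource["open"] is True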
+from __future__ import annotations from unittest import TestCase from unittest.mock import Mock diff --git a/unit_tests/test/mock_http/test_mocker.py b/unit_tests/test/mock_http/test_mocker.py index 0aff08b9..adf8641b 100644 --- a/unit_tests/test/mock_http/test_mocker.py +++ b/unit_tests/test/mock_http/test_mocker.py @@ -1,11 +1,14 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations from unittest import TestCase import pytest import requests + from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse + # Ensure that the scheme is HTTP as requests only partially supports other schemes # see https://github.com/psf/requests/blob/0b4d494192de489701d3a2e32acef8fb5d3f042e/src/requests/models.py#L424-L429 _A_URL = "http://test.com/" diff --git a/unit_tests/test/mock_http/test_request.py b/unit_tests/test/mock_http/test_request.py index 15d1f667..9c2726cb 100644 --- a/unit_tests/test/mock_http/test_request.py +++ b/unit_tests/test/mock_http/test_request.py @@ -1,8 +1,10 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations from unittest import TestCase import pytest + from airbyte_cdk.test.mock_http.request import ANY_QUERY_PARAMS, HttpRequest diff --git a/unit_tests/test/mock_http/test_response_builder.py b/unit_tests/test/mock_http/test_response_builder.py index cf7fbe50..d2f22332 100644 --- a/unit_tests/test/mock_http/test_response_builder.py +++ b/unit_tests/test/mock_http/test_response_builder.py @@ -1,12 +1,15 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations + import json from copy import deepcopy from pathlib import Path as FilePath -from typing import Any, Dict, Optional, Union +from typing import Any from unittest import TestCase from unittest.mock import Mock import pytest + from airbyte_cdk.test.mock_http.response import HttpResponse from airbyte_cdk.test.mock_http.response_builder import ( FieldPath, @@ -21,6 +24,7 @@ find_template, ) + _RECORDS_FIELD = "records_field" _ID_FIELD = "record_id" _CURSOR_FIELD = "record_cursor" @@ -33,10 +37,10 @@ def _record_builder( - response_template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - record_id_path: Optional[Path] = None, - record_cursor_path: Optional[Union[FieldPath, NestedPath]] = None, + response_template: dict[str, Any], + records_path: FieldPath | NestedPath, + record_id_path: Path | None = None, + record_cursor_path: FieldPath | NestedPath | None = None, ) -> RecordBuilder: return create_record_builder( deepcopy(response_template), records_path, record_id_path, record_cursor_path @@ -50,16 +54,16 @@ def _any_record_builder() -> RecordBuilder: def _response_builder( - response_template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - pagination_strategy: Optional[PaginationStrategy] = None, + response_template: dict[str, Any], + records_path: FieldPath | NestedPath, + pagination_strategy: PaginationStrategy | None = None, ) -> HttpResponseBuilder: return create_response_builder( deepcopy(response_template), records_path, pagination_strategy=pagination_strategy ) -def _body(response: HttpResponse) -> Dict[str, Any]: +def _body(response: HttpResponse) -> dict[str, Any]: return json.loads(response.body) diff --git a/unit_tests/test/test_entrypoint_wrapper.py b/unit_tests/test/test_entrypoint_wrapper.py index 3ead41f5..e6c94613 100644 --- a/unit_tests/test/test_entrypoint_wrapper.py +++ b/unit_tests/test/test_entrypoint_wrapper.py @@ -1,12 +1,16 @@ # 
Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations import json import logging import os -from typing import Any, Iterator, List, Mapping, Optional +from collections.abc import Iterator, Mapping +from typing import Any from unittest import TestCase from unittest.mock import Mock, patch +from orjson import orjson + from airbyte_cdk.models import ( AirbyteAnalyticsTraceMessage, AirbyteCatalog, @@ -31,7 +35,6 @@ from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, discover, read from airbyte_cdk.test.state_builder import StateBuilder -from orjson import orjson def _a_state_message(stream_name: str, stream_state: Mapping[str, Any]) -> AirbyteMessage: @@ -112,7 +115,7 @@ def _a_status_message(stream_name: str, status: AirbyteStreamStatus) -> AirbyteM _A_LOG_MESSAGE = "a log message" -def _to_entrypoint_output(messages: List[AirbyteMessage]) -> Iterator[str]: +def _to_entrypoint_output(messages: list[AirbyteMessage]) -> Iterator[str]: return (orjson.dumps(AirbyteMessageSerializer.dump(message)).decode() for message in messages) @@ -134,8 +137,8 @@ def _validate_tmp_catalog(expected, file_path) -> None: def _create_tmp_file_validation( entrypoint, expected_config, - expected_catalog: Optional[Any] = None, - expected_state: Optional[Any] = None, + expected_catalog: Any | None = None, + expected_state: Any | None = None, ): def _validate_tmp_files(self): _validate_tmp_json_file(expected_config, entrypoint.parse_args.call_args.args[0][2]) diff --git a/unit_tests/test_config_observation.py b/unit_tests/test_config_observation.py index 677e318f..4f9479b7 100644 --- a/unit_tests/test_config_observation.py +++ b/unit_tests/test_config_observation.py @@ -1,11 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import time import pytest + from airbyte_cdk.config_observation import ( ConfigObserver, ObservedDict, diff --git a/unit_tests/test_connector.py b/unit_tests/test_connector.py index bc7255b9..6971ff9a 100644 --- a/unit_tests/test_connector.py +++ b/unit_tests/test_connector.py @@ -1,21 +1,24 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import json import logging import os import sys import tempfile +from collections.abc import Mapping from pathlib import Path -from typing import Any, Mapping +from typing import Any import pytest import yaml + from airbyte_cdk import Connector from airbyte_cdk.models import AirbyteConnectionStatus + logger = logging.getLogger("airbyte") MODULE = sys.modules[__name__] @@ -28,7 +31,7 @@ def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCon pass -@pytest.fixture() +@pytest.fixture def mock_config(): return {"bogus": "file"} @@ -67,7 +70,7 @@ def test_read_non_json_config(nonjson_file, integration: Connector): def test_write_config(integration, mock_config): config_path = Path(tempfile.gettempdir()) / "config.json" integration.write_config(mock_config, str(config_path)) - with open(config_path, "r") as actual: + with open(config_path) as actual: assert json.loads(actual.read()) == mock_config diff --git a/unit_tests/test_counter.py b/unit_tests/test_counter.py index f6d2c22b..82c80af5 100644 --- a/unit_tests/test_counter.py +++ b/unit_tests/test_counter.py @@ -1,7 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
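The test_connector.py hunk above drops redundant arguments: a bare @pytest.fixture decorator and open(config_path) without the default "r" mode. A runnable sketch of the same round trip with an invented helper and temp path:

from __future__ import annotations

import json
import tempfile
from pathlib import Path


def roundtrip_config(config: dict[str, str]) -> dict[str, str]:
    """Write a config to a temp file and read it back, as the test does."""
    config_path = Path(tempfile.gettempdir()) / "example_config.json"
    config_path.write_text(json.dumps(config))

    # "r" is open()'s default mode, so it can be omitted; the context
    # manager still closes the file deterministically.
    with open(config_path) as handle:
        return json.loads(handle.read())


assert roundtrip_config({"bogus": "file"}) == {"bogus": "file"}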
# - +from __future__ import annotations from unittest import mock diff --git a/unit_tests/test_entrypoint.py b/unit_tests/test_entrypoint.py index 40781e89..2435b7b0 100644 --- a/unit_tests/test_entrypoint.py +++ b/unit_tests/test_entrypoint.py @@ -1,18 +1,22 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import os from argparse import Namespace from collections import defaultdict +from collections.abc import Mapping, MutableMapping from copy import deepcopy -from typing import Any, List, Mapping, MutableMapping, Union +from typing import Any from unittest import mock from unittest.mock import MagicMock, patch import freezegun import pytest import requests +from orjson import orjson + from airbyte_cdk import AirbyteEntrypoint from airbyte_cdk import entrypoint as entrypoint_module from airbyte_cdk.models import ( @@ -44,7 +48,6 @@ from airbyte_cdk.sources import Source from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor from airbyte_cdk.utils import AirbyteTracedException -from orjson import orjson class MockSource(Source): @@ -62,7 +65,7 @@ def message_repository(self): pass -def _as_arglist(cmd: str, named_args: Mapping[str, Any]) -> List[str]: +def _as_arglist(cmd: str, named_args: Mapping[str, Any]) -> list[str]: out = [cmd] for k, v in named_args.items(): out.append(f"--{k}") @@ -187,9 +190,10 @@ def test_parse_missing_required_args( def _wrap_message( - submessage: Union[ - AirbyteConnectionStatus, ConnectorSpecification, AirbyteRecordMessage, AirbyteCatalog - ], + submessage: AirbyteConnectionStatus + | ConnectorSpecification + | AirbyteRecordMessage + | AirbyteCatalog, ) -> str: if isinstance(submessage, AirbyteConnectionStatus): message = AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=submessage) diff --git a/unit_tests/test_exception_handler.py b/unit_tests/test_exception_handler.py index ee4bfaa1..0608a6ca 100644 --- a/unit_tests/test_exception_handler.py +++ b/unit_tests/test_exception_handler.py @@ -1,13 +1,14 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +from __future__ import annotations import json import subprocess import sys import pytest + from airbyte_cdk.exception_handler import assemble_uncaught_exception from airbyte_cdk.models import ( AirbyteErrorTraceMessage, diff --git a/unit_tests/test_logger.py b/unit_tests/test_logger.py index 3b6db8b8..176c72e4 100644 --- a/unit_tests/test_logger.py +++ b/unit_tests/test_logger.py @@ -1,12 +1,13 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from __future__ import annotations import json import logging -from typing import Dict import pytest + from airbyte_cdk.logger import AirbyteLogFormatter @@ -24,7 +25,7 @@ def test_formatter(logger, caplog): formatted_record_data = json.loads(formatted_record) assert formatted_record_data.get("type") == "LOG" log = formatted_record_data.get("log") - assert isinstance(log, Dict) + assert isinstance(log, dict) level = log.get("level") message = log.get("message") assert level == "INFO" diff --git a/unit_tests/test_secure_logger.py b/unit_tests/test_secure_logger.py index 4f2abf90..1c2b4868 100644 --- a/unit_tests/test_secure_logger.py +++ b/unit_tests/test_secure_logger.py @@ -1,13 +1,16 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
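The test_logger.py hunk above checks isinstance(log, dict) against the builtin instead of typing.Dict, matching the isinstance(..., list) change earlier in the patch. A small sketch with an invented log_level helper:

from __future__ import annotations

import json


def log_level(formatted_record: str) -> str:
    payload = json.loads(formatted_record)
    log = payload.get("log")
    # The builtin dict is the idiomatic isinstance target; typing.Dict is
    # just an alias kept around for older annotation styles.
    if not isinstance(log, dict):
        raise TypeError("expected the 'log' payload to be a JSON object")
    return log["level"]


assert log_level('{"type": "LOG", "log": {"level": "INFO", "message": "hi"}}') == "INFO"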
 #
+from __future__ import annotations

 import logging
 import sys
 from argparse import Namespace
-from typing import Any, Iterable, Mapping, MutableMapping
+from collections.abc import Iterable, Mapping, MutableMapping
+from typing import Any

 import pytest
+
 from airbyte_cdk import AirbyteEntrypoint
 from airbyte_cdk.logger import AirbyteLogFormatter
 from airbyte_cdk.models import (
@@ -19,6 +22,7 @@
 )
 from airbyte_cdk.sources import Source

+
 SECRET_PROPERTY = "api_token"
 ANOTHER_SECRET_PROPERTY = "another_api_token"
 ANOTHER_NOT_SECRET_PROPERTY = "not_secret_property"
@@ -111,7 +115,7 @@ def check(self, **kwargs):

 @pytest.fixture
 def simple_config():
-    yield {
+    return {
         SECRET_PROPERTY: I_AM_A_SECRET_VALUE,
         ANOTHER_SECRET_PROPERTY: ANOTHER_SECRET_VALUE,
     }
diff --git a/unit_tests/utils/test_datetime_format_inferrer.py b/unit_tests/utils/test_datetime_format_inferrer.py
index 1e69f3d1..519187ed 100644
--- a/unit_tests/utils/test_datetime_format_inferrer.py
+++ b/unit_tests/utils/test_datetime_format_inferrer.py
@@ -1,13 +1,14 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
-
-from typing import Dict, List
+from __future__ import annotations

 import pytest
+
 from airbyte_cdk.models import AirbyteRecordMessage
 from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
+

 NOW = 1234567


@@ -97,7 +98,7 @@
         ("no scope expand", [{}, {"d": "2022-02-03"}], {}),
     ],
 )
-def test_schema_inferrer(test_name, input_records: List, expected_candidate_fields: Dict[str, str]):
+def test_schema_inferrer(test_name, input_records: list, expected_candidate_fields: dict[str, str]):
     inferrer = DatetimeFormatInferrer()
     for record in input_records:
         inferrer.accumulate(AirbyteRecordMessage(stream="abc", data=record, emitted_at=NOW))
diff --git a/unit_tests/utils/test_mapping_helpers.py b/unit_tests/utils/test_mapping_helpers.py
index f5dc979e..7053f42c 100644
--- a/unit_tests/utils/test_mapping_helpers.py
+++ b/unit_tests/utils/test_mapping_helpers.py
@@ -1,8 +1,10 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from __future__ import annotations

 import pytest
+
 from airbyte_cdk.utils.mapping_helpers import combine_mappings


diff --git a/unit_tests/utils/test_message_utils.py b/unit_tests/utils/test_message_utils.py
index e4567164..6c47467e 100644
--- a/unit_tests/utils/test_message_utils.py
+++ b/unit_tests/utils/test_message_utils.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+from __future__ import annotations

 import pytest
+
 from airbyte_cdk.models import (
     AirbyteControlConnectorConfigMessage,
     AirbyteControlMessage,
diff --git a/unit_tests/utils/test_rate_limiting.py b/unit_tests/utils/test_rate_limiting.py
index d4a78140..281788a0 100644
--- a/unit_tests/utils/test_rate_limiting.py
+++ b/unit_tests/utils/test_rate_limiting.py
@@ -1,11 +1,13 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from __future__ import annotations

 import pytest
-from airbyte_cdk.sources.streams.http.rate_limiting import default_backoff_handler
 from requests import exceptions

+from airbyte_cdk.sources.streams.http.rate_limiting import default_backoff_handler
+

 def helper_with_exceptions(exception_type):
     raise exception_type
diff --git a/unit_tests/utils/test_schema_inferrer.py b/unit_tests/utils/test_schema_inferrer.py
index 535ff41d..320999cc 100644
--- a/unit_tests/utils/test_schema_inferrer.py
+++ b/unit_tests/utils/test_schema_inferrer.py
@@ -1,13 +1,16 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from __future__ import annotations

-from typing import List, Mapping
+from collections.abc import Mapping

 import pytest
+
 from airbyte_cdk.models import AirbyteRecordMessage
 from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
+

 NOW = 1234567


@@ -267,7 +270,7 @@
         ),
     ],
 )
-def test_schema_derivation(input_records: List, expected_schemas: Mapping):
+def test_schema_derivation(input_records: list, expected_schemas: Mapping):
     inferrer = SchemaInferrer()
     for record in input_records:
         inferrer.accumulate(
@@ -288,7 +291,7 @@ def test_schema_derivation(input_records: List, expected_schemas: Mapping):
 _IS_CURSOR_FIELD = True


-def _create_inferrer_with_required_field(is_pk: bool, field: List[List[str]]) -> SchemaInferrer:
+def _create_inferrer_with_required_field(is_pk: bool, field: list[list[str]]) -> SchemaInferrer:
     if is_pk:
         return SchemaInferrer(field)
     return SchemaInferrer([[]], field)
diff --git a/unit_tests/utils/test_secret_utils.py b/unit_tests/utils/test_secret_utils.py
index 846d0e12..e20ac801 100644
--- a/unit_tests/utils/test_secret_utils.py
+++ b/unit_tests/utils/test_secret_utils.py
@@ -1,8 +1,10 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from __future__ import annotations

 import pytest
+
 from airbyte_cdk.utils.airbyte_secrets_utils import (
     add_to_secrets,
     filter_secrets,
@@ -11,6 +13,7 @@
     update_secrets,
 )

+
 SECRET_STRING_KEY = "secret_key1"
 SECRET_STRING_VALUE = "secret_value"
 SECRET_STRING_2_KEY = "secret_key2"
diff --git a/unit_tests/utils/test_stream_status_utils.py b/unit_tests/utils/test_stream_status_utils.py
index 494eb7ee..608e9dfa 100644
--- a/unit_tests/utils/test_stream_status_utils.py
+++ b/unit_tests/utils/test_stream_status_utils.py
@@ -1,6 +1,7 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from __future__ import annotations

 from airbyte_cdk.models import (
     AirbyteMessage,
@@ -14,6 +15,7 @@
     as_airbyte_message as stream_status_as_airbyte_message,
 )

+
 stream = AirbyteStream(
     name="name", namespace="namespace", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]
 )
diff --git a/unit_tests/utils/test_traced_exception.py b/unit_tests/utils/test_traced_exception.py
index 2d2bcc81..fe5daead 100644
--- a/unit_tests/utils/test_traced_exception.py
+++ b/unit_tests/utils/test_traced_exception.py
@@ -1,9 +1,11 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
-
+from __future__ import annotations
 import pytest

+from orjson import orjson
+
 from airbyte_cdk.models import (
     AirbyteErrorTraceMessage,
     AirbyteMessage,
@@ -16,7 +18,7 @@
 )
 from airbyte_cdk.models import Type as MessageType
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException
-from orjson import orjson
+

 _AN_EXCEPTION = ValueError("An exception")
 _A_STREAM_DESCRIPTOR = StreamDescriptor(name="a_stream")
@@ -157,7 +159,7 @@ def test_given_both_from_exception_and_as_message_with_stream_descriptor_when_as
     assert message.trace.error.stream_descriptor == _A_STREAM_DESCRIPTOR


-def test_given_both_from_exception_and_as_sanitized_airbyte_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> (
+def test_given_both_from_exception_and_as_sanitized_airbyte_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> (  # Line too long
     None
 ):
     traced_exc = AirbyteTracedException.from_exception(