Skip to content

Commit

Permalink
chore: make full CDK codebase compliant with mypy (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers authored Dec 4, 2024
1 parent ac6cf92 commit c347760
Show file tree
Hide file tree
Showing 80 changed files with 599 additions and 435 deletions.
8 changes: 4 additions & 4 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ updates:
interval: daily
labels:
- chore
open-pull-requests-limit: 8 # default is 5
open-pull-requests-limit: 8 # default is 5

- package-ecosystem: github-actions
open-pull-requests-limit: 5 # default is 5
open-pull-requests-limit: 5 # default is 5
directory: "/"
commit-message:
prefix: "ci(deps): "
Expand All @@ -29,5 +29,5 @@ updates:
minor-and-patch:
applies-to: version-updates
update-types:
- patch
- minor
- patch
- minor
16 changes: 9 additions & 7 deletions .github/workflows/connector-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ jobs:
cdk_extra: n/a
- connector: source-shopify
cdk_extra: n/a
- connector: source-chargebee
cdk_extra: n/a
- connector: source-s3
cdk_extra: file-based
- connector: destination-pinecone
cdk_extra: vector-db-based
# Chargebee is being flaky:
# - connector: source-chargebee
# cdk_extra: n/a
# These two are behind in CDK updates and can't be used as tests until they are updated:
# - connector: source-s3
# cdk_extra: file-based
# - connector: destination-pinecone
# cdk_extra: vector-db-based
- connector: destination-motherduck
cdk_extra: sql
# ZenDesk currently failing (as of 2024-12-02)
Expand All @@ -91,7 +93,7 @@ jobs:
# - connector: source-pokeapi
# cdk_extra: n/a

name: "Check: '${{matrix.connector}}' (skip=${{needs.cdk_changes.outputs[matrix.cdk_extra] == 'false'}})"
name: "Check: '${{matrix.connector}}' (skip=${{needs.cdk_changes.outputs['src'] == 'false' || needs.cdk_changes.outputs[matrix.cdk_extra] == 'false'}})"
steps:
- name: Abort if extra not changed (${{matrix.cdk_extra}})
id: no_changes
Expand Down
13 changes: 5 additions & 8 deletions .github/workflows/publish_sdm_connector.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ on:
workflow_dispatch:
inputs:
version:
description:
The version to publish, ie 1.0.0 or 1.0.0-dev1.
If omitted, and if run from a release branch, the version will be
inferred from the git tag.
If omitted, and if run from a non-release branch, then only a SHA-based
Docker tag will be created.
description: The version to publish, ie 1.0.0 or 1.0.0-dev1.
If omitted, and if run from a release branch, the version will be
inferred from the git tag.
If omitted, and if run from a non-release branch, then only a SHA-based
Docker tag will be created.
required: false
dry_run:
description: If true, the workflow will not push to DockerHub.
Expand All @@ -25,7 +24,6 @@ jobs:
build:
runs-on: ubuntu-latest
steps:

- name: Detect Release Tag Version
if: startsWith(github.ref, 'refs/tags/v')
run: |
Expand Down Expand Up @@ -168,7 +166,6 @@ jobs:
tags: |
airbyte/source-declarative-manifest:${{ env.VERSION }}
- name: Build and push ('latest' tag)
# Only run if version is set and IS_PRERELEASE is false
if: env.VERSION != '' && env.IS_PRERELEASE == 'false' && github.event.inputs.dry_run == 'false'
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/pytest_matrix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ on:
branches:
- main
paths:
- 'airbyte_cdk/**'
- 'unit_tests/**'
- 'poetry.lock'
- 'pyproject.toml'
- "airbyte_cdk/**"
- "unit_tests/**"
- "poetry.lock"
- "pyproject.toml"
pull_request:

jobs:
Expand Down
13 changes: 3 additions & 10 deletions .github/workflows/python_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,7 @@ jobs:
- name: Install dependencies
run: poetry install --all-extras

# Job-specifc step(s):
# Job-specific step(s):

# For now, we run mypy only on modified files
- name: Get changed Python files
id: changed-py-files
uses: tj-actions/changed-files@v43
with:
files: "airbyte_cdk/**/*.py"
- name: Run mypy on changed files
if: steps.changed-py-files.outputs.any_changed == 'true'
run: poetry run mypy ${{ steps.changed-py-files.outputs.all_changed_files }} --config-file mypy.ini --install-types --non-interactive
- name: Run mypy
run: poetry run mypy --config-file mypy.ini airbyte_cdk
16 changes: 11 additions & 5 deletions airbyte_cdk/cli/source_declarative_manifest/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pathlib import Path
from typing import Any, cast

from orjson import orjson
import orjson

from airbyte_cdk.entrypoint import AirbyteEntrypoint, launch
from airbyte_cdk.models import (
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
super().__init__(
catalog=catalog,
config=config,
state=state,
state=state, # type: ignore [arg-type]
path_to_yaml="manifest.yaml",
)

Expand Down Expand Up @@ -152,18 +152,24 @@ def handle_remote_manifest_command(args: list[str]) -> None:
)


def create_declarative_source(args: list[str]) -> ConcurrentDeclarativeSource:
def create_declarative_source(
args: list[str],
) -> ConcurrentDeclarativeSource: # type: ignore [type-arg]
"""Creates the source with the injected config.
This essentially does what other low-code sources do at build time, but at runtime,
with a user-provided manifest in the config. This better reflects what happens in the
connector builder.
"""
try:
config: Mapping[str, Any] | None
catalog: ConfiguredAirbyteCatalog | None
state: list[AirbyteStateMessage]
config, catalog, state = _parse_inputs_into_config_catalog_state(args)
if "__injected_declarative_manifest" not in config:
if config is None or "__injected_declarative_manifest" not in config:
raise ValueError(
f"Invalid config: `__injected_declarative_manifest` should be provided at the root of the config but config only has keys {list(config.keys())}"
"Invalid config: `__injected_declarative_manifest` should be provided at the root "
f"of the config but config only has keys: {list(config.keys() if config else [])}"
)
return ConcurrentDeclarativeSource(
config=config,
Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from copy import copy
from typing import Any, List, MutableMapping

from orjson import orjson
import orjson

from airbyte_cdk.models import (
AirbyteControlConnectorConfigMessage,
Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/connector_builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from typing import Any, List, Mapping, Optional, Tuple

from orjson import orjson
import orjson

from airbyte_cdk.connector import BaseConnector
from airbyte_cdk.connector_builder.connector_builder_handler import (
Expand Down
20 changes: 10 additions & 10 deletions airbyte_cdk/connector_builder/message_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _cursor_field_to_nested_and_composite_field(

is_nested_key = isinstance(field[0], str)
if is_nested_key:
return [field] # type: ignore # the type of field is expected to be List[str] here
return [field]

raise ValueError(f"Unknown type for cursor field `{field}")

Expand Down Expand Up @@ -232,9 +232,9 @@ def _get_message_groups(
current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
current_slice_pages = []
at_least_one_page_in_group = False
elif message.type == MessageType.LOG and message.log.message.startswith(
elif message.type == MessageType.LOG and message.log.message.startswith( # type: ignore[union-attr] # None doesn't have 'message'
SliceLogger.SLICE_LOG_PREFIX
): # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
):
# parsing the first slice
current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
elif message.type == MessageType.LOG:
Expand Down Expand Up @@ -274,14 +274,14 @@ def _get_message_groups(
if message.trace.type == TraceType.ERROR: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has trace.type
yield message.trace
elif message.type == MessageType.RECORD:
current_page_records.append(message.record.data) # type: ignore[union-attr] # AirbyteMessage with MessageType.RECORD has record.data
current_page_records.append(message.record.data) # type: ignore[arg-type, union-attr] # AirbyteMessage with MessageType.RECORD has record.data
records_count += 1
schema_inferrer.accumulate(message.record)
datetime_format_inferrer.accumulate(message.record)
elif (
message.type == MessageType.CONTROL
and message.control.type == OrchestratorType.CONNECTOR_CONFIG
): # type: ignore[union-attr] # AirbyteMessage with MessageType.CONTROL has control.type
and message.control.type == OrchestratorType.CONNECTOR_CONFIG # type: ignore[union-attr] # None doesn't have 'type'
):
yield message.control
elif message.type == MessageType.STATE:
latest_state_message = message.state # type: ignore[assignment]
Expand Down Expand Up @@ -310,8 +310,8 @@ def _need_to_close_page(
and message.type == MessageType.LOG
and (
MessageGrouper._is_page_http_request(json_message)
or message.log.message.startswith("slice:")
) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
or message.log.message.startswith("slice:") # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
)
)

@staticmethod
Expand Down Expand Up @@ -355,8 +355,8 @@ def _close_page(
StreamReadPages(
request=current_page_request,
response=current_page_response,
records=deepcopy(current_page_records),
) # type: ignore
records=deepcopy(current_page_records), # type: ignore [arg-type]
)
)
current_page_records.clear()

Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/destinations/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Mapping

from orjson import orjson
import orjson

from airbyte_cdk.connector import Connector
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
Expand Down
4 changes: 2 additions & 2 deletions airbyte_cdk/destinations/vector_db_based/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def embedding_dimensions(self) -> int:
class OpenAIEmbedder(BaseOpenAIEmbedder):
def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
super().__init__(
OpenAIEmbeddings(
OpenAIEmbeddings( # type: ignore [call-arg]
openai_api_key=config.openai_key, max_retries=15, disallowed_special=()
),
chunk_size,
Expand All @@ -118,7 +118,7 @@ class AzureOpenAIEmbedder(BaseOpenAIEmbedder):
def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
# Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request
super().__init__(
OpenAIEmbeddings(
OpenAIEmbeddings( # type: ignore [call-arg]
openai_api_key=config.openai_key,
chunk_size=16,
max_retries=15,
Expand Down
16 changes: 12 additions & 4 deletions airbyte_cdk/destinations/vector_db_based/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,19 @@ def write(
yield message
elif message.type == Type.RECORD:
record_chunks, record_id_to_delete = self.processor.process(message.record)
self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks)
if record_id_to_delete is not None:
self.ids_to_delete[(message.record.namespace, message.record.stream)].append(
record_id_to_delete
self.chunks[
( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]"
message.record.namespace, # type: ignore [union-attr] # record not None
message.record.stream, # type: ignore [union-attr] # record not None
)
].extend(record_chunks)
if record_id_to_delete is not None:
self.ids_to_delete[
( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]"
message.record.namespace, # type: ignore [union-attr] # record not None
message.record.stream, # type: ignore [union-attr] # record not None
)
].append(record_id_to_delete)
self.number_of_chunks += len(record_chunks)
if self.number_of_chunks >= self.batch_size:
self._process_batch()
Expand Down
13 changes: 7 additions & 6 deletions airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from airbyte_cdk.connector import TConfig
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
from airbyte_cdk.logger import init_logger
from airbyte_cdk.models import ( # type: ignore [attr-defined]
from airbyte_cdk.models import (
AirbyteConnectionStatus,
AirbyteMessage,
AirbyteMessageSerializer,
Expand Down Expand Up @@ -255,9 +255,10 @@ def handle_record_counts(

stream_message_count[
HashableStreamDescriptor(
name=message.record.stream, namespace=message.record.namespace
name=message.record.stream, # type: ignore[union-attr] # record has `stream`
namespace=message.record.namespace, # type: ignore[union-attr] # record has `namespace`
)
] += 1.0 # type: ignore[union-attr] # record has `stream` and `namespace`
] += 1.0
case Type.STATE:
if message.state is None:
raise ValueError("State message must have a state attribute")
Expand All @@ -266,9 +267,9 @@ def handle_record_counts(

# Set record count from the counter onto the state message
message.state.sourceStats = message.state.sourceStats or AirbyteStateStats() # type: ignore[union-attr] # state has `sourceStats`
message.state.sourceStats.recordCount = stream_message_count.get(
message.state.sourceStats.recordCount = stream_message_count.get( # type: ignore[union-attr] # state has `sourceStats`
stream_descriptor, 0.0
) # type: ignore[union-attr] # state has `sourceStats`
)

# Reset the counter
stream_message_count[stream_descriptor] = 0.0
Expand All @@ -290,7 +291,7 @@ def set_up_secret_filter(config: TConfig, connection_specification: Mapping[str,

@staticmethod
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
return orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string
return orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode()

@classmethod
def extract_state(cls, args: List[str]) -> Optional[Any]:
Expand Down
4 changes: 2 additions & 2 deletions airbyte_cdk/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging.config
from typing import Any, Callable, Mapping, Optional, Tuple

from orjson import orjson
import orjson

from airbyte_cdk.models import (
AirbyteLogMessage,
Expand Down Expand Up @@ -78,7 +78,7 @@ def format(self, record: logging.LogRecord) -> str:
log_message = AirbyteMessage(
type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message)
)
return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string
return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode()

@staticmethod
def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/sources/abstract_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def read(
if len(stream_name_to_exception) > 0:
error_message = generate_failed_streams_error_message(
{key: [value] for key, value in stream_name_to_exception.items()}
) # type: ignore # for some reason, mypy can't figure out the types for key and value
)
logger.info(error_message)
# We still raise at least one exception when a stream raises an exception because the platform currently relies
# on a non-zero exit code to determine if a sync attempt has failed. We also raise the exception as a config_error
Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/sources/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
rename_key(schema, old_key="anyOf", new_key="oneOf") # UI supports only oneOf
expand_refs(schema)
schema.pop("description", None) # description added from the docstring
return schema # type: ignore[no-any-return]
return schema
Loading

0 comments on commit c347760

Please sign in to comment.