Skip to content

Commit

Permalink
auto-fix lint issues with ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Nov 10, 2024
1 parent fb99f92 commit 6bbb030
Show file tree
Hide file tree
Showing 214 changed files with 1,905 additions and 1,107 deletions.
2 changes: 1 addition & 1 deletion airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __setitem__(self, item: Any, value: Any) -> None:
for i, sub_value in enumerate(value):
if isinstance(sub_value, MutableMapping):
value[i] = ObservedDict(sub_value, self.observer)
super(ObservedDict, self).__setitem__(item, value)
super().__setitem__(item, value)
if self.update_on_unchanged_value or value != previous_value:
self.observer.update()

Expand Down
11 changes: 7 additions & 4 deletions airbyte_cdk/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from __future__ import annotations

import json
import logging
import os
import pkgutil
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, Generic, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar

import yaml

Expand All @@ -20,6 +19,10 @@
)


if TYPE_CHECKING:
import logging


def load_optional_package_file(package: str, filename: str) -> bytes | None:
"""Gets a resource from a package, returning None if it does not exist"""
try:
Expand Down Expand Up @@ -50,7 +53,7 @@ def read_config(config_path: str) -> Mapping[str, Any]:

@staticmethod
def _read_json_file(file_path: str) -> Any:
with open(file_path) as file:
with open(file_path, encoding="utf-8") as file:
contents = file.read()

try:
Expand All @@ -62,7 +65,7 @@ def _read_json_file(file_path: str) -> Any:

@staticmethod
def write_config(config: TConfig, config_path: str) -> None:
with open(config_path, "w") as fh:
with open(config_path, "w", encoding="utf-8") as fh:
fh.write(json.dumps(config))

def spec(self, logger: logging.Logger) -> ConnectorSpecification:
Expand Down
10 changes: 7 additions & 3 deletions airbyte_cdk/connector_builder/connector_builder_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from __future__ import annotations

import dataclasses
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING, Any

from airbyte_cdk.connector_builder.message_grouper import MessageGrouper
from airbyte_cdk.models import (
Expand All @@ -17,7 +16,6 @@
Type,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
ModelToComponentFactory,
Expand All @@ -26,6 +24,12 @@
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


if TYPE_CHECKING:
from collections.abc import Mapping

from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource


DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
DEFAULT_MAXIMUM_RECORDS = 100
Expand Down
14 changes: 10 additions & 4 deletions airbyte_cdk/connector_builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from __future__ import annotations

import sys
from collections.abc import Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

from orjson import orjson
import orjson

from airbyte_cdk.connector import BaseConnector
from airbyte_cdk.connector_builder.connector_builder_handler import (
Expand All @@ -25,11 +24,18 @@
ConfiguredAirbyteCatalog,
ConfiguredAirbyteCatalogSerializer,
)
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.source import Source
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


if TYPE_CHECKING:
from collections.abc import Mapping

from airbyte_cdk.sources.declarative.manifest_declarative_source import (
ManifestDeclarativeSource,
)


def get_config_and_catalog_from_args(
args: list[str],
) -> tuple[str, Mapping[str, Any], ConfiguredAirbyteCatalog | None, Any]:
Expand Down
19 changes: 14 additions & 5 deletions airbyte_cdk/connector_builder/message_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import json
import logging
from collections.abc import Iterable, Iterator, Mapping
from copy import deepcopy
from json import JSONDecodeError
from typing import Any
from typing import TYPE_CHECKING, Any

from airbyte_cdk.connector_builder.models import (
AuxiliaryRequest,
Expand All @@ -31,18 +30,28 @@
TraceType,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException


if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping

from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.utils.types import JsonType


class MessageGrouper:
logger = logging.getLogger("airbyte.connector-builder")

def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000):
def __init__(
self,
max_pages_per_slice: int,
max_slices: int,
max_record_limit: int = 1000,
) -> None:
self._max_pages_per_slice = max_pages_per_slice
self._max_slices = max_slices
self._max_record_limit = max_record_limit
Expand Down
13 changes: 8 additions & 5 deletions airbyte_cdk/destinations/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

from orjson import orjson

Expand All @@ -26,6 +25,10 @@
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


if TYPE_CHECKING:
from collections.abc import Iterable, Mapping


logger = logging.getLogger("airbyte")


Expand Down Expand Up @@ -62,7 +65,7 @@ def _run_write(
input_stream: io.TextIOWrapper,
) -> Iterable[AirbyteMessage]:
catalog = ConfiguredAirbyteCatalogSerializer.load(
orjson.loads(open(configured_catalog_path).read())
orjson.loads(open(configured_catalog_path, encoding="utf-8").read())
)
input_messages = self._parse_input_stream(input_stream)
logger.info("Begin writing to the destination...")
Expand Down Expand Up @@ -109,7 +112,7 @@ def parse_args(self, args: list[str]) -> argparse.Namespace:
cmd = parsed_args.command
if not cmd:
raise Exception("No command entered. ")
if cmd not in ["spec", "check", "write"]:
if cmd not in {"spec", "check", "write"}:
# This is technically dead code since parse_args() would fail if this was the case
# But it's non-obvious enough to warrant placing it here anyways
raise Exception(f"Unknown command entered: {cmd}")
Expand All @@ -134,7 +137,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
if connection_status and cmd == "check":
yield connection_status
return
raise traced_exc
raise

if cmd == "check":
yield self._run_check(config=config)
Expand Down
18 changes: 12 additions & 6 deletions airbyte_cdk/destinations/vector_db_based/document_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import json
import logging
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING, Any

import dpath
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
Expand All @@ -29,6 +28,10 @@
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType


if TYPE_CHECKING:
from collections.abc import Mapping


METADATA_STREAM_FIELD = "_ab_stream"
METADATA_RECORD_ID_FIELD = "_ab_record_id"

Expand Down Expand Up @@ -116,8 +119,13 @@ def _get_text_splitter(
),
disallowed_special=(),
)
return None

def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog):
def __init__(
self,
config: ProcessingConfigModel,
catalog: ConfiguredAirbyteCatalog,
) -> None:
self.streams = {
create_stream_identifier(stream.stream): stream for stream in catalog.streams
}
Expand Down Expand Up @@ -154,9 +162,7 @@ def process(self, record: AirbyteRecordMessage) -> tuple[list[Chunk], str | None
for chunk_document in self._split_document(doc)
]
id_to_delete = (
doc.metadata[METADATA_RECORD_ID_FIELD]
if METADATA_RECORD_ID_FIELD in doc.metadata
else None
doc.metadata.get(METADATA_RECORD_ID_FIELD, None)
)
return chunks, id_to_delete

Expand Down
55 changes: 35 additions & 20 deletions airbyte_cdk/destinations/vector_db_based/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,30 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import cast
from typing import TYPE_CHECKING, cast

from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.localai import LocalAIEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings

from airbyte_cdk.destinations.vector_db_based.config import (
AzureOpenAIEmbeddingConfigModel,
CohereEmbeddingConfigModel,
FakeEmbeddingConfigModel,
FromFieldEmbeddingConfigModel,
OpenAICompatibleEmbeddingConfigModel,
OpenAIEmbeddingConfigModel,
ProcessingConfigModel,
)
from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception
from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType


if TYPE_CHECKING:
from airbyte_cdk.destinations.vector_db_based.config import (
AzureOpenAIEmbeddingConfigModel,
CohereEmbeddingConfigModel,
FakeEmbeddingConfigModel,
FromFieldEmbeddingConfigModel,
OpenAICompatibleEmbeddingConfigModel,
OpenAIEmbeddingConfigModel,
ProcessingConfigModel,
)
from airbyte_cdk.models import AirbyteRecordMessage


@dataclass
class Document:
page_content: str
Expand Down Expand Up @@ -67,7 +70,11 @@ def embedding_dimensions(self) -> int:


class BaseOpenAIEmbedder(Embedder):
def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int):
def __init__(
self,
embeddings: OpenAIEmbeddings,
chunk_size: int,
) -> None:
super().__init__()
self.embeddings = embeddings
self.chunk_size = chunk_size
Expand Down Expand Up @@ -103,7 +110,11 @@ def embedding_dimensions(self) -> int:


class OpenAIEmbedder(BaseOpenAIEmbedder):
def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
def __init__(
self,
config: OpenAIEmbeddingConfigModel,
chunk_size: int,
) -> None:
super().__init__(
OpenAIEmbeddings(
openai_api_key=config.openai_key, max_retries=15, disallowed_special=()
Expand All @@ -113,7 +124,11 @@ def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):


class AzureOpenAIEmbedder(BaseOpenAIEmbedder):
def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
def __init__(
self,
config: AzureOpenAIEmbeddingConfigModel,
chunk_size: int,
) -> None:
# Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request
super().__init__(
OpenAIEmbeddings(
Expand All @@ -134,7 +149,7 @@ def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):


class CohereEmbedder(Embedder):
def __init__(self, config: CohereEmbeddingConfigModel):
def __init__(self, config: CohereEmbeddingConfigModel) -> None:
super().__init__()
# Client is set internally
self.embeddings = CohereEmbeddings(
Expand All @@ -161,7 +176,7 @@ def embedding_dimensions(self) -> int:


class FakeEmbedder(Embedder):
def __init__(self, config: FakeEmbeddingConfigModel):
def __init__(self, config: FakeEmbeddingConfigModel) -> None:
super().__init__()
self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE)

Expand All @@ -188,7 +203,7 @@ def embedding_dimensions(self) -> int:


class OpenAICompatibleEmbedder(Embedder):
def __init__(self, config: OpenAICompatibleEmbeddingConfigModel):
def __init__(self, config: OpenAICompatibleEmbeddingConfigModel) -> None:
super().__init__()
self.config = config
# Client is set internally
Expand Down Expand Up @@ -228,7 +243,7 @@ def embedding_dimensions(self) -> int:


class FromFieldEmbedder(Embedder):
def __init__(self, config: FromFieldEmbeddingConfigModel):
def __init__(self, config: FromFieldEmbeddingConfigModel) -> None:
super().__init__()
self.config = config

Expand All @@ -249,7 +264,7 @@ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
)
field = data[self.config.field_name]
if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field):
if not isinstance(field, list) or not all(isinstance(x, int | float) for x in field):
raise AirbyteTracedException(
internal_message="Embedding vector field not a list of numbers",
failure_type=FailureType.config_error,
Expand Down Expand Up @@ -289,7 +304,7 @@ def create_from_config(
| OpenAICompatibleEmbeddingConfigModel,
processing_config: ProcessingConfigModel,
) -> Embedder:
if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai":
if embedding_config.mode in {"azure_openai", "openai"}:
return cast(
Embedder,
embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size),
Expand Down
Loading

0 comments on commit 6bbb030

Please sign in to comment.