Skip to content

Commit

Permalink
Auto-fix lint and format issues
Browse files Browse the repository at this point in the history
  • Loading branch information
octavia-squidington-iii committed Nov 10, 2024
1 parent 46f0c47 commit aded694
Show file tree
Hide file tree
Showing 74 changed files with 1,481 additions and 1,615 deletions.
6 changes: 2 additions & 4 deletions airbyte_cdk/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def check(self, logger: logging.Logger, config: TConfig) -> AirbyteConnectionSta

class _WriteConfigProtocol(Protocol):
@staticmethod
def write_config(config: Mapping[str, Any], config_path: str) -> None:
...
def write_config(config: Mapping[str, Any], config_path: str) -> None: ...


class DefaultConnectorMixin:
Expand All @@ -108,5 +107,4 @@ def configure(self: _WriteConfigProtocol, config: Mapping[str, Any], temp_dir: s
return config


class Connector(DefaultConnectorMixin, BaseConnector[Mapping[str, Any]], ABC):
...
class Connector(DefaultConnectorMixin, BaseConnector[Mapping[str, Any]], ABC): ...
4 changes: 3 additions & 1 deletion airbyte_cdk/connector_builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def handle_request(args: List[str]) -> str:
command, config, catalog, state = get_config_and_catalog_from_args(args)
limits = get_limits(config)
source = create_source(config, limits)
return orjson.dumps(AirbyteMessageSerializer.dump(handle_connector_builder_request(source, command, config, catalog, state, limits))).decode() # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
return orjson.dumps(
AirbyteMessageSerializer.dump(handle_connector_builder_request(source, command, config, catalog, state, limits))
).decode() # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion airbyte_cdk/destinations/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def parse_args(self, args: List[str]) -> argparse.Namespace:
return parsed_args

def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:

cmd = parsed_args.command
if cmd not in self.VALID_CMDS:
raise Exception(f"Unrecognized command: {cmd}")
Expand Down
23 changes: 20 additions & 3 deletions airbyte_cdk/destinations/vector_db_based/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,19 @@ def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
class AzureOpenAIEmbedder(BaseOpenAIEmbedder):
def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
# Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request
super().__init__(OpenAIEmbeddings(openai_api_key=config.openai_key, chunk_size=16, max_retries=15, openai_api_type="azure", openai_api_version="2023-05-15", openai_api_base=config.api_base, deployment=config.deployment, disallowed_special=()), chunk_size) # type: ignore
super().__init__(
OpenAIEmbeddings(
openai_api_key=config.openai_key,
chunk_size=16,
max_retries=15,
openai_api_type="azure",
openai_api_version="2023-05-15",
openai_api_base=config.api_base,
deployment=config.deployment,
disallowed_special=(),
),
chunk_size,
) # type: ignore


COHERE_VECTOR_SIZE = 1024
Expand Down Expand Up @@ -167,7 +179,13 @@ def __init__(self, config: OpenAICompatibleEmbeddingConfigModel):
self.config = config
# Client is set internally
# Always set an API key even if there is none defined in the config because the validator will fail otherwise. Embedding APIs that don't require an API key don't fail if one is provided, so this is not breaking usage.
self.embeddings = LocalAIEmbeddings(model=config.model_name, openai_api_key=config.api_key or "dummy-api-key", openai_api_base=config.base_url, max_retries=15, disallowed_special=()) # type: ignore
self.embeddings = LocalAIEmbeddings(
model=config.model_name,
openai_api_key=config.api_key or "dummy-api-key",
openai_api_base=config.base_url,
max_retries=15,
disallowed_special=(),
) # type: ignore

def check(self) -> Optional[str]:
deployment_mode = os.environ.get("DEPLOYMENT_MODE", "")
Expand Down Expand Up @@ -254,7 +272,6 @@ def create_from_config(
],
processing_config: ProcessingConfigModel,
) -> Embedder:

if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai":
return cast(Embedder, embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size))
else:
Expand Down
6 changes: 3 additions & 3 deletions airbyte_cdk/models/airbyte_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class AirbyteGlobalState:
class AirbyteStateMessage:
type: Optional[AirbyteStateType] = None # type: ignore [name-defined]
stream: Optional[AirbyteStreamState] = None
global_: Annotated[
AirbyteGlobalState | None, Alias("global")
] = None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization
global_: Annotated[AirbyteGlobalState | None, Alias("global")] = (
None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization
)
data: Optional[Dict[str, Any]] = None
sourceStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
Expand Down
3 changes: 2 additions & 1 deletion airbyte_cdk/sources/connector_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _extract_from_state_message(
else:
streams = {
HashableStreamDescriptor(
name=per_stream_state.stream.stream_descriptor.name, namespace=per_stream_state.stream.stream_descriptor.namespace # type: ignore[union-attr] # stream has stream_descriptor
name=per_stream_state.stream.stream_descriptor.name,
namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor
): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state
for per_stream_state in state
if per_stream_state.type == AirbyteStateType.STREAM and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):

# By default, we defer to a value of 1 which represents running a connector using the Concurrent CDK engine on only one thread.
SINGLE_THREADED_CONCURRENCY_LEVEL = 1

Expand Down Expand Up @@ -99,7 +98,6 @@ def read(
catalog: ConfiguredAirbyteCatalog,
state: Optional[Union[List[AirbyteStateMessage]]] = None,
) -> Iterator[AirbyteMessage]:

# ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of the concurrent
# streams must be saved so that they can be removed from the catalog before starting synchronous streams
if self._concurrent_streams:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ def _get_request_options(self, option_type: RequestOptionType, stream_slice: Opt
self._partition_field_start.eval(self.config)
)
if self.end_time_option and self.end_time_option.inject_into == option_type:
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self._partition_field_end.eval(self.config)) # type: ignore # field_name is always casted to an interpolated string
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(
self._partition_field_end.eval(self.config)
) # type: ignore # field_name is always casted to an interpolated string
return options

def should_be_synced(self, record: Record) -> bool:
Expand Down
Loading

0 comments on commit aded694

Please sign in to comment.