From 6995005b65ec016ea00149a509d02e0817ed2ef2 Mon Sep 17 00:00:00 2001 From: Oleksandr Bazarnov Date: Thu, 19 Dec 2024 17:27:01 +0200 Subject: [PATCH 1/2] fix --- airbyte_cdk/sources/declarative/auth/oauth.py | 32 ++++++++++++------- .../declarative_component_schema.yaml | 7 +++- .../models/declarative_component_schema.py | 10 ++++-- .../parsers/model_to_component_factory.py | 4 ++- .../requests_native_auth/abstract_oauth.py | 15 +++++++-- .../sources/declarative/auth/test_oauth.py | 19 +++++++++++ 6 files changed, 68 insertions(+), 19 deletions(-) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 8ec671f3e..e11d74db0 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -43,11 +43,11 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut message_repository (MessageRepository): the message repository used to emit logs on HTTP requests """ - token_refresh_endpoint: Union[InterpolatedString, str] client_id: Union[InterpolatedString, str] client_secret: Union[InterpolatedString, str] config: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] + token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None refresh_token: Optional[Union[InterpolatedString, str]] = None scopes: Optional[List[str]] = None token_expiry_date: Optional[Union[InterpolatedString, str]] = None @@ -55,6 +55,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut token_expiry_date_format: Optional[str] = None token_expiry_is_time_of_expiration: bool = False access_token_name: Union[InterpolatedString, str] = "access_token" + access_token_value: Optional[str] = None expires_in_name: Union[InterpolatedString, str] = "expires_in" refresh_request_body: Optional[Mapping[str, Any]] = None grant_type: Union[InterpolatedString, str] = "refresh_token" @@ -62,9 +63,12 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut def __post_init__(self, parameters: Mapping[str, Any]) -> None: super().__init__() - self._token_refresh_endpoint = InterpolatedString.create( - self.token_refresh_endpoint, parameters=parameters - ) + if self.token_refresh_endpoint is not None: + self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create( + self.token_refresh_endpoint, parameters=parameters + ) + else: + self._token_refresh_endpoint = None self._client_id = InterpolatedString.create(self.client_id, parameters=parameters) self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters) if self.refresh_token is not None: @@ -92,20 +96,24 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.token_expiry_date else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints ) - self._access_token: Optional[str] = None # access_token is initialized by a setter + self._access_token: Optional[str] = ( + self.access_token_value if self.access_token_value else None + ) if self.get_grant_type() == "refresh_token" and self._refresh_token is None: raise ValueError( "OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`" ) - def get_token_refresh_endpoint(self) -> str: - refresh_token: str = self._token_refresh_endpoint.eval(self.config) - if not refresh_token: - raise ValueError( - "OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter" - ) - return refresh_token + def get_token_refresh_endpoint(self) -> Optional[str]: + if self._token_refresh_endpoint is not None: + refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config) + if not refresh_token_endpoint: + raise ValueError( + "OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter" + ) + return refresh_token_endpoint + return None def get_client_id(self) -> str: client_id: str = self._client_id.eval(self.config) diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 461cfa764..a4842eff9 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -1021,7 +1021,6 @@ definitions: - type - client_id - client_secret - - token_refresh_endpoint properties: type: type: string @@ -1060,6 +1059,12 @@ definitions: default: "access_token" examples: - access_token + access_token_value: + title: Access Token Value + description: The value of the access_token to bypass the token refreshing using `refresh_token`. + type: string + examples: + - secret_access_token_value expires_in_name: title: Token Expiry Property Name description: The name of the property which contains the expiry date in the response from the token refresh endpoint. diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 2b4bba030..8a3f0862e 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -489,8 +489,8 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Token", ) - token_refresh_endpoint: str = Field( - ..., + token_refresh_endpoint: Optional[str] = Field( + None, description="The full URL to call to obtain a new access token.", examples=["https://connect.squareup.com/oauth2/token"], title="Token Refresh Endpoint", @@ -501,6 +501,12 @@ class OAuthAuthenticator(BaseModel): examples=["access_token"], title="Access Token Property Name", ) + access_token_value: Optional[str] = Field( + None, + description="The value of the access_token to bypass the token refreshing using `refresh_token`.", + examples=["secret_access_token_value"], + title="Access Token Value", + ) expires_in_name: Optional[str] = Field( "expires_in", description="The name of the property which contains the expiry date in the response from the token refresh endpoint.", diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 215d6fff9..4bc2f05b9 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -1765,7 +1765,8 @@ def create_oauth_authenticator( return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore config, InterpolatedString.create( - model.token_refresh_endpoint, parameters=model.parameters or {} + model.token_refresh_endpoint, # type: ignore + parameters=model.parameters or {}, ).eval(config), access_token_name=InterpolatedString.create( model.access_token_name or "access_token", parameters=model.parameters or {} @@ -1799,6 +1800,7 @@ def create_oauth_authenticator( # ignore type error because fixing it would have a lot of dependencies, revisit later return DeclarativeOauth2Authenticator( # type: ignore access_token_name=model.access_token_name or "access_token", + access_token_value=model.access_token_value, client_id=model.client_id, client_secret=model.client_secret, expires_in_name=model.expires_in_name or "expires_in", diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 01c9d60d0..a7590b88f 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -54,7 +54,16 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques def get_auth_header(self) -> Mapping[str, Any]: """HTTP header to set on the requests""" - return {"Authorization": f"Bearer {self.get_access_token()}"} + token = ( + self.access_token + if ( + not self.get_token_refresh_endpoint() + or not self.get_refresh_token() + and self.access_token + ) + else self.get_access_token() + ) + return {"Authorization": f"Bearer {token}"} def get_access_token(self) -> str: """Returns the access token""" @@ -121,7 +130,7 @@ def _get_refresh_access_token_response(self) -> Any: try: response = requests.request( method="POST", - url=self.get_token_refresh_endpoint(), + url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected. data=self.build_refresh_request_body(), ) if response.ok: @@ -198,7 +207,7 @@ def token_expiry_date_format(self) -> Optional[str]: return None @abstractmethod - def get_token_refresh_endpoint(self) -> str: + def get_token_refresh_endpoint(self) -> Optional[str]: """Returns the endpoint to refresh the access token""" @abstractmethod diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index bce87bab2..c77f2bf84 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -129,6 +129,25 @@ def test_refresh_without_refresh_token(self): } assert body == expected + def test_get_auth_header_without_refresh_token_and_without_refresh_token_endpoint(self): + """ + Coverred the case when the `access_token_value` is supplied, + without `token_refresh_endpoint` or `refresh_token` provided. + + In this case, it's expected to have the `access_token_value` provided to return the permanent `auth header`, + contains the authentication. + """ + oauth = DeclarativeOauth2Authenticator( + access_token_value="my_access_token_value", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + config=config, + parameters={}, + grant_type="client_credentials", + ) + + assert oauth.get_auth_header() == {"Authorization": "Bearer my_access_token_value"} + def test_error_on_refresh_token_grant_without_refresh_token(self): """ Should throw an error if grant_type refresh_token is configured without refresh_token. From 028f0c16a222a6b56d4e89b17ef63d3c3dead958 Mon Sep 17 00:00:00 2001 From: Oleksandr Bazarnov Date: Thu, 19 Dec 2024 18:54:09 +0200 Subject: [PATCH 2/2] corrected after the review --- airbyte_cdk/sources/declarative/auth/oauth.py | 11 +++++++++-- unit_tests/sources/declarative/auth/test_oauth.py | 6 +++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index e11d74db0..f3ba528ac 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -55,7 +55,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut token_expiry_date_format: Optional[str] = None token_expiry_is_time_of_expiration: bool = False access_token_name: Union[InterpolatedString, str] = "access_token" - access_token_value: Optional[str] = None + access_token_value: Optional[Union[InterpolatedString, str]] = None expires_in_name: Union[InterpolatedString, str] = "expires_in" refresh_request_body: Optional[Mapping[str, Any]] = None grant_type: Union[InterpolatedString, str] = "refresh_token" @@ -96,8 +96,15 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.token_expiry_date else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints ) + if self.access_token_value is not None: + self._access_token_value = InterpolatedString.create( + self.access_token_value, parameters=parameters + ).eval(self.config) + else: + self._access_token_value = None + self._access_token: Optional[str] = ( - self.access_token_value if self.access_token_value else None + self._access_token_value if self.access_token_value else None ) if self.get_grant_type() == "refresh_token" and self._refresh_token is None: diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index c77f2bf84..4130a9dc8 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -26,6 +26,7 @@ "custom_field": "in_outbound_request", "another_field": "exists_in_body", "grant_type": "some_grant_type", + "access_token": "some_access_token", } parameters = {"refresh_token": "some_refresh_token"} @@ -138,15 +139,14 @@ def test_get_auth_header_without_refresh_token_and_without_refresh_token_endpoin contains the authentication. """ oauth = DeclarativeOauth2Authenticator( - access_token_value="my_access_token_value", + access_token_value="{{ config['access_token'] }}", client_id="{{ config['client_id'] }}", client_secret="{{ config['client_secret'] }}", config=config, parameters={}, grant_type="client_credentials", ) - - assert oauth.get_auth_header() == {"Authorization": "Bearer my_access_token_value"} + assert oauth.get_auth_header() == {"Authorization": "Bearer some_access_token"} def test_error_on_refresh_token_grant_without_refresh_token(self): """