Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bazarnov committed Dec 19, 2024
1 parent 216cd43 commit 6995005
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 19 deletions.
32 changes: 20 additions & 12 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,32 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
"""

token_refresh_endpoint: Union[InterpolatedString, str]
client_id: Union[InterpolatedString, str]
client_secret: Union[InterpolatedString, str]
config: Mapping[str, Any]
parameters: InitVar[Mapping[str, Any]]
token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
refresh_token: Optional[Union[InterpolatedString, str]] = None
scopes: Optional[List[str]] = None
token_expiry_date: Optional[Union[InterpolatedString, str]] = None
_token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None)
token_expiry_date_format: Optional[str] = None
token_expiry_is_time_of_expiration: bool = False
access_token_name: Union[InterpolatedString, str] = "access_token"
access_token_value: Optional[str] = None
expires_in_name: Union[InterpolatedString, str] = "expires_in"
refresh_request_body: Optional[Mapping[str, Any]] = None
grant_type: Union[InterpolatedString, str] = "refresh_token"
message_repository: MessageRepository = NoopMessageRepository()

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__init__()
self._token_refresh_endpoint = InterpolatedString.create(
self.token_refresh_endpoint, parameters=parameters
)
if self.token_refresh_endpoint is not None:
self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create(
self.token_refresh_endpoint, parameters=parameters
)
else:
self._token_refresh_endpoint = None
self._client_id = InterpolatedString.create(self.client_id, parameters=parameters)
self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
if self.refresh_token is not None:
Expand Down Expand Up @@ -92,20 +96,24 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.token_expiry_date
else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints
)
self._access_token: Optional[str] = None # access_token is initialized by a setter
self._access_token: Optional[str] = (
self.access_token_value if self.access_token_value else None
)

if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
raise ValueError(
"OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`"
)

def get_token_refresh_endpoint(self) -> str:
refresh_token: str = self._token_refresh_endpoint.eval(self.config)
if not refresh_token:
raise ValueError(
"OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
)
return refresh_token
def get_token_refresh_endpoint(self) -> Optional[str]:
if self._token_refresh_endpoint is not None:
refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config)
if not refresh_token_endpoint:
raise ValueError(
"OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
)
return refresh_token_endpoint
return None

def get_client_id(self) -> str:
client_id: str = self._client_id.eval(self.config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,6 @@ definitions:
- type
- client_id
- client_secret
- token_refresh_endpoint
properties:
type:
type: string
Expand Down Expand Up @@ -1060,6 +1059,12 @@ definitions:
default: "access_token"
examples:
- access_token
access_token_value:
title: Access Token Value
description: The value of the access_token to bypass the token refreshing using `refresh_token`.
type: string
examples:
- secret_access_token_value
expires_in_name:
title: Token Expiry Property Name
description: The name of the property which contains the expiry date in the response from the token refresh endpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Token",
)
token_refresh_endpoint: str = Field(
...,
token_refresh_endpoint: Optional[str] = Field(
None,
description="The full URL to call to obtain a new access token.",
examples=["https://connect.squareup.com/oauth2/token"],
title="Token Refresh Endpoint",
Expand All @@ -501,6 +501,12 @@ class OAuthAuthenticator(BaseModel):
examples=["access_token"],
title="Access Token Property Name",
)
access_token_value: Optional[str] = Field(
None,
description="The value of the access_token to bypass the token refreshing using `refresh_token`.",
examples=["secret_access_token_value"],
title="Access Token Value",
)
expires_in_name: Optional[str] = Field(
"expires_in",
description="The name of the property which contains the expiry date in the response from the token refresh endpoint.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1765,7 +1765,8 @@ def create_oauth_authenticator(
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
config,
InterpolatedString.create(
model.token_refresh_endpoint, parameters=model.parameters or {}
model.token_refresh_endpoint, # type: ignore
parameters=model.parameters or {},
).eval(config),
access_token_name=InterpolatedString.create(
model.access_token_name or "access_token", parameters=model.parameters or {}
Expand Down Expand Up @@ -1799,6 +1800,7 @@ def create_oauth_authenticator(
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeOauth2Authenticator( # type: ignore
access_token_name=model.access_token_name or "access_token",
access_token_value=model.access_token_value,
client_id=model.client_id,
client_secret=model.client_secret,
expires_in_name=model.expires_in_name or "expires_in",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,16 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques

def get_auth_header(self) -> Mapping[str, Any]:
"""HTTP header to set on the requests"""
return {"Authorization": f"Bearer {self.get_access_token()}"}
token = (
self.access_token
if (
not self.get_token_refresh_endpoint()
or not self.get_refresh_token()
and self.access_token
)
else self.get_access_token()
)
return {"Authorization": f"Bearer {token}"}

def get_access_token(self) -> str:
"""Returns the access token"""
Expand Down Expand Up @@ -121,7 +130,7 @@ def _get_refresh_access_token_response(self) -> Any:
try:
response = requests.request(
method="POST",
url=self.get_token_refresh_endpoint(),
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
data=self.build_refresh_request_body(),
)
if response.ok:
Expand Down Expand Up @@ -198,7 +207,7 @@ def token_expiry_date_format(self) -> Optional[str]:
return None

@abstractmethod
def get_token_refresh_endpoint(self) -> str:
def get_token_refresh_endpoint(self) -> Optional[str]:
"""Returns the endpoint to refresh the access token"""

@abstractmethod
Expand Down
19 changes: 19 additions & 0 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,25 @@ def test_refresh_without_refresh_token(self):
}
assert body == expected

def test_get_auth_header_without_refresh_token_and_without_refresh_token_endpoint(self):
"""
Coverred the case when the `access_token_value` is supplied,
without `token_refresh_endpoint` or `refresh_token` provided.
In this case, it's expected to have the `access_token_value` provided to return the permanent `auth header`,
contains the authentication.
"""
oauth = DeclarativeOauth2Authenticator(
access_token_value="my_access_token_value",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
config=config,
parameters={},
grant_type="client_credentials",
)

assert oauth.get_auth_header() == {"Authorization": "Bearer my_access_token_value"}

def test_error_on_refresh_token_grant_without_refresh_token(self):
"""
Should throw an error if grant_type refresh_token is configured without refresh_token.
Expand Down

0 comments on commit 6995005

Please sign in to comment.