Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: (DeclarativeOAuthFlow) - allow DeclarativeOauth2Authenticator to use access_token directly when no token_refresh_endpoint or refresh_token values are provided. #182

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,32 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
"""

token_refresh_endpoint: Union[InterpolatedString, str]
client_id: Union[InterpolatedString, str]
client_secret: Union[InterpolatedString, str]
config: Mapping[str, Any]
parameters: InitVar[Mapping[str, Any]]
token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
refresh_token: Optional[Union[InterpolatedString, str]] = None
scopes: Optional[List[str]] = None
token_expiry_date: Optional[Union[InterpolatedString, str]] = None
_token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None)
token_expiry_date_format: Optional[str] = None
token_expiry_is_time_of_expiration: bool = False
access_token_name: Union[InterpolatedString, str] = "access_token"
access_token_value: Optional[Union[InterpolatedString, str]] = None
expires_in_name: Union[InterpolatedString, str] = "expires_in"
refresh_request_body: Optional[Mapping[str, Any]] = None
grant_type: Union[InterpolatedString, str] = "refresh_token"
message_repository: MessageRepository = NoopMessageRepository()

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__init__()
self._token_refresh_endpoint = InterpolatedString.create(
self.token_refresh_endpoint, parameters=parameters
)
if self.token_refresh_endpoint is not None:
self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create(
self.token_refresh_endpoint, parameters=parameters
)
else:
self._token_refresh_endpoint = None
self._client_id = InterpolatedString.create(self.client_id, parameters=parameters)
self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
if self.refresh_token is not None:
Expand Down Expand Up @@ -92,20 +96,31 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.token_expiry_date
else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints
)
self._access_token: Optional[str] = None # access_token is initialized by a setter
if self.access_token_value is not None:
self._access_token_value = InterpolatedString.create(
self.access_token_value, parameters=parameters
).eval(self.config)
else:
self._access_token_value = None

self._access_token: Optional[str] = (
self._access_token_value if self.access_token_value else None
)

if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
raise ValueError(
"OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`"
)

def get_token_refresh_endpoint(self) -> str:
refresh_token: str = self._token_refresh_endpoint.eval(self.config)
if not refresh_token:
raise ValueError(
"OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
)
return refresh_token
def get_token_refresh_endpoint(self) -> Optional[str]:
if self._token_refresh_endpoint is not None:
refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config)
if not refresh_token_endpoint:
raise ValueError(
"OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
)
return refresh_token_endpoint
return None

def get_client_id(self) -> str:
client_id: str = self._client_id.eval(self.config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,6 @@ definitions:
- type
- client_id
- client_secret
- token_refresh_endpoint
properties:
type:
type: string
Expand Down Expand Up @@ -1060,6 +1059,12 @@ definitions:
default: "access_token"
examples:
- access_token
access_token_value:
title: Access Token Value
description: The value of the access_token to bypass the token refreshing using `refresh_token`.
type: string
examples:
- secret_access_token_value
expires_in_name:
title: Token Expiry Property Name
description: The name of the property which contains the expiry date in the response from the token refresh endpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Token",
)
token_refresh_endpoint: str = Field(
...,
token_refresh_endpoint: Optional[str] = Field(
None,
description="The full URL to call to obtain a new access token.",
examples=["https://connect.squareup.com/oauth2/token"],
title="Token Refresh Endpoint",
Expand All @@ -501,6 +501,12 @@ class OAuthAuthenticator(BaseModel):
examples=["access_token"],
title="Access Token Property Name",
)
access_token_value: Optional[str] = Field(
None,
description="The value of the access_token to bypass the token refreshing using `refresh_token`.",
examples=["secret_access_token_value"],
title="Access Token Value",
)
expires_in_name: Optional[str] = Field(
"expires_in",
description="The name of the property which contains the expiry date in the response from the token refresh endpoint.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,8 @@ def create_oauth_authenticator(
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
config,
InterpolatedString.create(
model.token_refresh_endpoint, parameters=model.parameters or {}
model.token_refresh_endpoint, # type: ignore
parameters=model.parameters or {},
).eval(config),
access_token_name=InterpolatedString.create(
model.access_token_name or "access_token", parameters=model.parameters or {}
Expand Down Expand Up @@ -1834,6 +1835,7 @@ def create_oauth_authenticator(
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeOauth2Authenticator( # type: ignore
access_token_name=model.access_token_name or "access_token",
access_token_value=model.access_token_value,
client_id=model.client_id,
client_secret=model.client_secret,
expires_in_name=model.expires_in_name or "expires_in",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,16 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques

def get_auth_header(self) -> Mapping[str, Any]:
"""HTTP header to set on the requests"""
return {"Authorization": f"Bearer {self.get_access_token()}"}
token = (
self.access_token
if (
not self.get_token_refresh_endpoint()
or not self.get_refresh_token()
and self.access_token
)
else self.get_access_token()
)
return {"Authorization": f"Bearer {token}"}

def get_access_token(self) -> str:
"""Returns the access token"""
Expand Down Expand Up @@ -121,7 +130,7 @@ def _get_refresh_access_token_response(self) -> Any:
try:
response = requests.request(
method="POST",
url=self.get_token_refresh_endpoint(),
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
data=self.build_refresh_request_body(),
)
if response.ok:
Expand Down Expand Up @@ -198,7 +207,7 @@ def token_expiry_date_format(self) -> Optional[str]:
return None

@abstractmethod
def get_token_refresh_endpoint(self) -> str:
def get_token_refresh_endpoint(self) -> Optional[str]:
"""Returns the endpoint to refresh the access token"""

@abstractmethod
Expand Down
19 changes: 19 additions & 0 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"custom_field": "in_outbound_request",
"another_field": "exists_in_body",
"grant_type": "some_grant_type",
"access_token": "some_access_token",
}
parameters = {"refresh_token": "some_refresh_token"}

Expand Down Expand Up @@ -129,6 +130,24 @@ def test_refresh_without_refresh_token(self):
}
assert body == expected

def test_get_auth_header_without_refresh_token_and_without_refresh_token_endpoint(self):
"""
Coverred the case when the `access_token_value` is supplied,
without `token_refresh_endpoint` or `refresh_token` provided.

In this case, it's expected to have the `access_token_value` provided to return the permanent `auth header`,
contains the authentication.
"""
oauth = DeclarativeOauth2Authenticator(
access_token_value="{{ config['access_token'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
config=config,
parameters={},
grant_type="client_credentials",
)
assert oauth.get_auth_header() == {"Authorization": "Bearer some_access_token"}

def test_error_on_refresh_token_grant_without_refresh_token(self):
"""
Should throw an error if grant_type refresh_token is configured without refresh_token.
Expand Down
Loading