Skip to content

Commit

Permalink
Merge pull request #458 from supertokens/feat/network-interceptor-hook
Browse files Browse the repository at this point in the history
feat: added network interceptor hook
  • Loading branch information
rishabhpoddar authored Nov 9, 2023
2 parents 4d73959 + 2961c44 commit 0011932
Show file tree
Hide file tree
Showing 36 changed files with 718 additions and 152 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.16.8] - 2023-11-7

### Added

- Added `network_interceptor` to the `supertokens_config` in `init`.
- This can be used to capture/modify all the HTTP requests sent to the core.
- Solves the issue - https://github.com/supertokens/supertokens-core/issues/865

### Fixes
- The sync functions `create_user_id_mapping` and `delete_user_id_mapping` now take the `force` parameter as an optional argument, just like their async counterparts.
- Functions `get_users_oldest_first`, `get_users_newest_first`, `get_user_count`, `delete_user`, `create_user_id_mapping`, `get_user_id_mapping`, `delete_user_id_mapping` and `update_or_delete_user_id_mapping_info` now accept `user_context` as an optional argument.
- Fixed the dependencies in the example apps
- Example apps will now fetch the latest version of the frameworks

## [0.16.7] - 2023-11-2

- Added `debug` flag in `init()`. If set to `True`, debug logs will be printed.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
Django==4.0.4
django-cors-headers==3.12.0
python-dotenv==0.19.2
supertokens-python
django-cors-headers
python-dotenv
supertokens-python
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
fastapi==0.68.1
uvicorn==0.16.0
python-dotenv==0.19.2
fastapi
uvicorn
python-dotenv
supertokens-python
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
flask==2.0.1
flask_cors==3.0.10
python-dotenv==0.19.2
flask
flask_cors
python-dotenv
supertokens-python
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.16.7",
version="0.16.8",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
50 changes: 39 additions & 11 deletions supertokens_python/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any

from supertokens_python import Supertokens
from supertokens_python.interfaces import (
Expand All @@ -33,9 +33,16 @@ async def get_users_oldest_first(
pagination_token: Union[str, None] = None,
include_recipe_ids: Union[None, List[str]] = None,
query: Union[None, Dict[str, str]] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> UsersResponse:
return await Supertokens.get_instance().get_users(
tenant_id, "ASC", limit, pagination_token, include_recipe_ids, query
tenant_id,
"ASC",
limit,
pagination_token,
include_recipe_ids,
query,
user_context,
)


Expand All @@ -45,61 +52,82 @@ async def get_users_newest_first(
pagination_token: Union[str, None] = None,
include_recipe_ids: Union[None, List[str]] = None,
query: Union[None, Dict[str, str]] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> UsersResponse:
return await Supertokens.get_instance().get_users(
tenant_id, "DESC", limit, pagination_token, include_recipe_ids, query
tenant_id,
"DESC",
limit,
pagination_token,
include_recipe_ids,
query,
user_context,
)


async def get_user_count(
include_recipe_ids: Union[None, List[str]] = None, tenant_id: Optional[str] = None
include_recipe_ids: Union[None, List[str]] = None,
tenant_id: Optional[str] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> int:
return await Supertokens.get_instance().get_user_count(
include_recipe_ids, tenant_id
include_recipe_ids, tenant_id, user_context
)


async def delete_user(user_id: str) -> None:
return await Supertokens.get_instance().delete_user(user_id)
async def delete_user(
user_id: str, user_context: Optional[Dict[str, Any]] = None
) -> None:
return await Supertokens.get_instance().delete_user(user_id, user_context)


async def create_user_id_mapping(
supertokens_user_id: str,
external_user_id: str,
external_user_id_info: Optional[str] = None,
force: Optional[bool] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> Union[
CreateUserIdMappingOkResult,
UnknownSupertokensUserIDError,
UserIdMappingAlreadyExistsError,
]:
return await Supertokens.get_instance().create_user_id_mapping(
supertokens_user_id, external_user_id, external_user_id_info, force
supertokens_user_id,
external_user_id,
external_user_id_info,
force,
user_context,
)


async def get_user_id_mapping(
user_id: str,
user_id_type: Optional[UserIDTypes] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> Union[GetUserIdMappingOkResult, UnknownMappingError]:
return await Supertokens.get_instance().get_user_id_mapping(user_id, user_id_type)
return await Supertokens.get_instance().get_user_id_mapping(
user_id, user_id_type, user_context
)


async def delete_user_id_mapping(
user_id: str,
user_id_type: Optional[UserIDTypes] = None,
force: Optional[bool] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> DeleteUserIdMappingOkResult:
return await Supertokens.get_instance().delete_user_id_mapping(
user_id, user_id_type, force
user_id, user_id_type, force, user_context
)


async def update_or_delete_user_id_mapping_info(
user_id: str,
user_id_type: Optional[UserIDTypes] = None,
external_user_id_info: Optional[str] = None,
user_context: Optional[Dict[str, Any]] = None,
) -> Union[UpdateOrDeleteUserIdMappingInfoOkResult, UnknownMappingError]:
return await Supertokens.get_instance().update_or_delete_user_id_mapping_info(
user_id, user_id_type, external_user_id_info
user_id, user_id_type, external_user_id_info, user_context
)
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.16.7"
VERSION = "0.16.8"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
118 changes: 108 additions & 10 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
from __future__ import annotations

import asyncio

from json import JSONDecodeError
from os import environ
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple

from httpx import AsyncClient, ConnectTimeout, NetworkError, Response

Expand Down Expand Up @@ -50,6 +49,25 @@ class Querier:
api_version = None
__last_tried_index: int = 0
__hosts_alive_for_testing: Set[str] = set()
network_interceptor: Optional[
Callable[
[
str,
str,
Dict[str, Any],
Optional[Dict[str, Any]],
Optional[Dict[str, Any]],
Optional[Dict[str, Any]],
],
Tuple[
str,
str,
Dict[str, Any],
Optional[Dict[str, Any]],
Optional[Dict[str, Any]],
],
]
] = None

def __init__(self, hosts: List[Host], rid_to_core: Union[None, str] = None):
self.__hosts = hosts
Expand Down Expand Up @@ -141,14 +159,37 @@ def get_instance(rid_to_core: Union[str, None] = None):
return Querier(Querier.__hosts, rid_to_core)

@staticmethod
def init(hosts: List[Host], api_key: Union[str, None] = None):
def init(
hosts: List[Host],
api_key: Union[str, None] = None,
network_interceptor: Optional[
Callable[
[
str,
str,
Dict[str, Any],
Optional[Dict[str, Any]],
Optional[Dict[str, Any]],
Optional[Dict[str, Any]],
],
Tuple[
str,
str,
Dict[str, Any],
Optional[Dict[str, Any]],
Optional[Dict[str, Any]],
],
]
] = None,
):
if not Querier.__init_called:
Querier.__init_called = True
Querier.__hosts = hosts
Querier.__api_key = api_key
Querier.api_version = None
Querier.__last_tried_index = 0
Querier.__hosts_alive_for_testing = set()
Querier.network_interceptor = network_interceptor

async def __get_headers_with_api_version(self, path: NormalisedURLPath):
headers = {API_VERSION_HEADER: await self.get_api_version()}
Expand All @@ -159,17 +200,33 @@ async def __get_headers_with_api_version(self, path: NormalisedURLPath):
return headers

async def send_get_request(
self, path: NormalisedURLPath, params: Union[Dict[str, Any], None] = None
self,
path: NormalisedURLPath,
params: Union[Dict[str, Any], None],
user_context: Union[Dict[str, Any], None],
) -> Dict[str, Any]:
if params is None:
params = {}

async def f(url: str, method: str) -> Response:
headers = await self.__get_headers_with_api_version(path)
nonlocal params
if Querier.network_interceptor is not None:
(
url,
method,
headers,
params,
_,
) = Querier.network_interceptor( # pylint:disable=not-callable
url, method, headers, params, {}, user_context
)

return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
headers=headers,
params=params,
)

Expand All @@ -178,7 +235,8 @@ async def f(url: str, method: str) -> Response:
async def send_post_request(
self,
path: NormalisedURLPath,
data: Union[Dict[str, Any], None] = None,
data: Union[Dict[str, Any], None],
user_context: Union[Dict[str, Any], None],
test: bool = False,
) -> Dict[str, Any]:
if data is None:
Expand All @@ -195,35 +253,64 @@ async def send_post_request(
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str, method: str) -> Response:
nonlocal headers, data
if Querier.network_interceptor is not None:
(
url,
method,
headers,
_,
data,
) = Querier.network_interceptor( # pylint:disable=not-callable
url, method, headers, {}, data, user_context
)
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
headers=headers,
json=data,
)

return await self.__send_request_helper(path, "POST", f, len(self.__hosts))

async def send_delete_request(
self, path: NormalisedURLPath, params: Union[Dict[str, Any], None] = None
self,
path: NormalisedURLPath,
params: Union[Dict[str, Any], None],
user_context: Union[Dict[str, Any], None],
) -> Dict[str, Any]:
if params is None:
params = {}

async def f(url: str, method: str) -> Response:
headers = await self.__get_headers_with_api_version(path)
nonlocal params
if Querier.network_interceptor is not None:
(
url,
method,
headers,
params,
_,
) = Querier.network_interceptor( # pylint:disable=not-callable
url, method, headers, params, {}, user_context
)
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
headers=headers,
params=params,
)

return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))

async def send_put_request(
self, path: NormalisedURLPath, data: Union[Dict[str, Any], None] = None
self,
path: NormalisedURLPath,
data: Union[Dict[str, Any], None],
user_context: Union[Dict[str, Any], None],
) -> Dict[str, Any]:
if data is None:
data = {}
Expand All @@ -232,6 +319,17 @@ async def send_put_request(
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str, method: str) -> Response:
nonlocal headers, data
if Querier.network_interceptor is not None:
(
url,
method,
headers,
_,
data,
) = Querier.network_interceptor( # pylint:disable=not-callable
url, method, headers, {}, data, user_context
)
return await self.api_request(url, method, 2, headers=headers, json=data)

return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))
Expand Down
Loading

0 comments on commit 0011932

Please sign in to comment.