From 5c24247dbae51238f9845085749b021cb1e27013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikael=20Engstr=C3=B6m?= Date: Thu, 21 Nov 2024 08:38:56 +0100 Subject: [PATCH] Fix #744: Bulk messages with aoiapns not working properly Previous version implemented aioapns in a way that made it hang indefinetly. Especially when receiver list contaned a lot of bad tokens. Having a lot of bad tokens still affects the reliability and transfer speed of notification sets, which is why this fix also deactivate devices for error codes BadDeviceToken and DeviceTokenNotForTopic unlike previous versions. --- push_notifications/apns_async.py | 378 +++++++++++++++++-------------- 1 file changed, 202 insertions(+), 176 deletions(-) diff --git a/push_notifications/apns_async.py b/push_notifications/apns_async.py index a0710d85..7c87344f 100644 --- a/push_notifications/apns_async.py +++ b/push_notifications/apns_async.py @@ -1,5 +1,6 @@ import asyncio import time + from dataclasses import asdict, dataclass from typing import Awaitable, Callable, Dict, Optional, Union @@ -111,37 +112,7 @@ def asDict(self) -> dict[str, any]: } -class APNsService: - __slots__ = ("client",) - - def __init__( - self, - application_id: str = None, - creds: Credentials = None, - topic: str = None, - err_func: ErrFunc = None, - ): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - self.client = self._create_client( - creds=creds, application_id=application_id, topic=topic, err_func=err_func - ) - - def send_message( - self, - request: NotificationRequest, - ): - loop = asyncio.get_event_loop() - routine = self.client.send_notification(request) - res = loop.run_until_complete(routine) - return res - - def _create_notification_request_from_args( - self, +def _create_notification_request_from_args( registration_id: str, alert: Union[str, Alert], badge: int = None, @@ -155,105 +126,103 @@ def _create_notification_request_from_args( aps_kwargs: dict = {}, message_kwargs: dict = {}, notification_request_kwargs: dict = {}, - ): - if alert is None: - alert = Alert(body="") +): + if alert is None: + alert = Alert(body="") - if loc_key: - if isinstance(alert, str): - alert = Alert(body=alert) - alert.loc_key = loc_key + if loc_key: + if isinstance(alert, str): + alert = Alert(body=alert) + alert.loc_key = loc_key - if isinstance(alert, Alert): - alert = alert.asDict() + if isinstance(alert, Alert): + alert = alert.asDict() - notification_request_kwargs_out = notification_request_kwargs.copy() + notification_request_kwargs_out = notification_request_kwargs.copy() - if expiration is not None: - notification_request_kwargs_out["time_to_live"] = expiration - int( - time.time() - ) - if priority is not None: - notification_request_kwargs_out["priority"] = priority - - if collapse_id is not None: - notification_request_kwargs_out["collapse_key"] = collapse_id - - request = NotificationRequest( - device_token=registration_id, - message={ - "aps": { - "alert": alert, - "badge": badge, - "sound": sound, - "thread-id": thread_id, - **aps_kwargs, - }, - **extra, - **message_kwargs, - }, - **notification_request_kwargs_out, + if expiration is not None: + notification_request_kwargs_out["time_to_live"] = expiration - int( + time.time() ) + if priority is not None: + notification_request_kwargs_out["priority"] = priority + + if collapse_id is not None: + notification_request_kwargs_out["collapse_key"] = collapse_id + + request = NotificationRequest( + device_token=registration_id, + message={ + "aps": { + "alert": alert, + "badge": badge, + "sound": sound, + "thread-id": thread_id, + **aps_kwargs, + }, + **extra, + **message_kwargs, + }, + **notification_request_kwargs_out, + ) - return request + return request - def _create_client( - self, + +def _create_client( creds: Credentials = None, application_id: str = None, topic=None, err_func: ErrFunc = None, - ) -> APNs: - use_sandbox = get_manager().get_apns_use_sandbox(application_id) - if topic is None: - topic = get_manager().get_apns_topic(application_id) - if creds is None: - creds = self._get_credentials(application_id) - - client = APNs( - **asdict(creds), - topic=topic, # Bundle ID - use_sandbox=use_sandbox, - err_func=err_func, - ) - return client - - def _get_credentials(self, application_id): - if not get_manager().has_auth_token_creds(application_id): - # TLS certificate authentication - cert = get_manager().get_apns_certificate(application_id) - return CertificateCredentials( - client_cert=cert, - ) - else: - # Token authentication - keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) +) -> APNs: + use_sandbox = get_manager().get_apns_use_sandbox(application_id) + if topic is None: + topic = get_manager().get_apns_topic(application_id) + if creds is None: + creds = _get_credentials(application_id) + + client = APNs( + **asdict(creds), + topic=topic, # Bundle ID + use_sandbox=use_sandbox, + err_func=err_func, + ) + return client -# Public interface +def _get_credentials(application_id): + if not get_manager().has_auth_token_creds(application_id): + # TLS certificate authentication + cert = get_manager().get_apns_certificate(application_id) + return CertificateCredentials( + client_cert=cert, + ) + else: + # Token authentication + keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) def apns_send_message( - registration_id: str, - alert: Union[str, Alert], - application_id: str = None, - creds: Credentials = None, - topic: str = None, - badge: int = None, - sound: str = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - err_func: ErrFunc = None, + registration_id: str, + alert: Union[str, Alert], + application_id: str = None, + creds: Credentials = None, + topic: str = None, + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + err_func: ErrFunc = None, ): """ Sends an APNS notification to a single registration_id. @@ -270,50 +239,45 @@ def apns_send_message( :param application_id: The application_id to use :param creds: The credentials to use """ + results = apns_send_bulk_message( + registration_ids=[registration_id], + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + err_func=err_func, + ) - try: - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - - request = apns_service._create_notification_request_from_args( - registration_id, - alert, - badge=badge, - sound=sound, - extra=extra, - expiration=expiration, - thread_id=thread_id, - loc_key=loc_key, - priority=priority, - collapse_id=collapse_id, - ) - res = apns_service.send_message(request) - if not res.is_successful: - if res.description == "Unregistered": - models.APNSDevice.objects.filter( - registration_id=registration_id - ).update(active=False) - raise APNSServerError(status=res.description) - except ConnectionError as e: - raise APNSServerError(status=e.__class__.__name__) + for result in results.values(): + if result == "Success": + return {"results": [result]} + else: + return {"results": [{"error": result}]} def apns_send_bulk_message( - registration_ids: list[str], - alert: Union[str, Alert], - application_id: str = None, - creds: Credentials = None, - topic: str = None, - badge: int = None, - sound: str = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - err_func: ErrFunc = None, + registration_ids: list[str], + alert: Union[str, Alert], + application_id: str = None, + creds: Credentials = None, + topic: str = None, + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + err_func: ErrFunc = None, ): """ Sends an APNS notification to one or more registration_ids. @@ -328,17 +292,17 @@ def apns_send_bulk_message( :param application_id: The application_id to use :param creds: The credentials to use """ - - topic = get_manager().get_apns_topic(application_id) - results: Dict[str, str] = {} - inactive_tokens = [] - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - for registration_id in registration_ids: - request = apns_service._create_notification_request_from_args( - registration_id, - alert, + try: + topic = get_manager().get_apns_topic(application_id) + results: Dict[str, str] = {} + inactive_tokens = [] + + responses = asyncio.run(_send_bulk_request( + registration_ids=registration_ids, + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, badge=badge, sound=sound, extra=extra, @@ -347,17 +311,79 @@ def apns_send_bulk_message( loc_key=loc_key, priority=priority, collapse_id=collapse_id, - ) + err_func=err_func, + )) - result = apns_service.send_message(request) - results[registration_id] = ( - "Success" if result.is_successful else result.description - ) - if not result.is_successful and result.description == "Unregistered": - inactive_tokens.append(registration_id) + results = {} + for registration_id, result in responses: + results[registration_id] = ( + "Success" if result.is_successful else result.description + ) + if not result.is_successful and result.description in ["Unregistered", "BadDeviceToken", + "DeviceTokenNotForTopic"]: + inactive_tokens.append(registration_id) + + if len(inactive_tokens) > 0: + models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( + active=False + ) + + return results - if len(inactive_tokens) > 0: - models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( - active=False + except ConnectionError as e: + raise APNSServerError(status=e.__class__.__name__) + + +async def _send_bulk_request( + registration_ids: list[str], + alert: Union[str, Alert], + application_id: str = None, + creds: Credentials = None, + topic: str = None, + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + err_func: ErrFunc = None, +): + client = _create_client( + creds=creds, application_id=application_id, topic=topic, err_func=err_func + ) + + requests = [_create_notification_request_from_args( + registration_id, + alert, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + ) for registration_id in registration_ids] + + send_requests = [_send_request(client, request) for request in requests] + return await asyncio.gather(*send_requests) + + +async def _send_request(apns, request): + try: + res = await asyncio.wait_for(apns.send_notification(request), timeout=1) + return request.device_token, res + except asyncio.TimeoutError: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="TimeoutError" + ) + except: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="CommunicationError" ) - return results