Skip to content

Commit

Permalink
FCM v1: use async version of google-auth and add HTTP proxy support (#…
Browse files Browse the repository at this point in the history
…372)

* FCM v1: use async version of google-auth and add HTTP proxy support

* Fix test

* Add changelog

* Address comments

* lint

* Add tests

* Remove print

* Fix test
  • Loading branch information
MatMaul authored May 18, 2024
1 parent 534d845 commit bb62b17
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog.d/372.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FCM v1: use async version of google-auth and add HTTP proxy support.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ dependencies = [
"attrs>=19.2.0",
"cryptography>=2.6.1",
"idna>=2.8",
"google-auth>=2.27.0",
"google-auth[aiohttp]>=2.27.0",
"jaeger-client>=4.0.0",
"matrix-common==1.3.0",
"opentracing>=2.2.0",
Expand All @@ -104,6 +104,7 @@ dev = [
"mypy-zope==1.0.1",
"towncrier",
"tox",
"google-auth-stubs==0.2.0",
"types-opentracing>=2.4.2",
"types-pyOpenSSL",
"types-PyYAML",
Expand Down
79 changes: 55 additions & 24 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
import time
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple

import google.auth.transport.requests
from google.oauth2 import service_account
# We are using an unstable async google-auth API, but it's there since 3+ years
# https://github.com/googleapis/google-auth-library-python/issues/613
import aiohttp
import google.auth.transport._aiohttp_requests
from google.auth._default_async import load_credentials_from_file
from google.oauth2._credentials_async import Credentials
from opentracing import Span, logs, tags
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.defer import DeferredSemaphore
from twisted.internet.defer import Deferred, DeferredSemaphore
from twisted.web.client import FileBodyProducer, HTTPConnectionPool, readBody
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
Expand Down Expand Up @@ -180,10 +186,33 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None:
"Must configure `project_id` when using FCM api v1",
)

self.service_account_file = self.get_config("service_account_file", str)
if self.api_version is APIVersion.V1 and not self.service_account_file:
raise PushkinSetupException(
"Must configure `service_account_file` when using FCM api v1",
self.credentials: Optional[Credentials] = None

if self.api_version is APIVersion.V1:
self.service_account_file = self.get_config("service_account_file", str)
if not self.service_account_file:
raise PushkinSetupException(
"Must configure `service_account_file` when using FCM api v1",
)
try:
self.credentials, _ = load_credentials_from_file(
str(self.service_account_file),
scopes=AUTH_SCOPES,
)
except google.auth.exceptions.DefaultCredentialsError as e:
raise PushkinSetupException(
f"`service_account_file` must be valid: {str(e)}",
)

session = None
if proxy_url:
# `ClientSession` can't directly take the proxy URL, so we need to
# set the usual env var and use `trust_env=True`
os.environ["HTTPS_PROXY"] = proxy_url
session = aiohttp.ClientSession(trust_env=True, auto_decompress=False)

self.google_auth_request = google.auth.transport._aiohttp_requests.Request(
session=session
)

# Use the fcm_options config dictionary as a foundation for the body;
Expand Down Expand Up @@ -464,21 +493,26 @@ def _handle_v1_response(
f"Unknown GCM response code {response.code}"
)

def _get_access_token(self) -> str:
"""Retrieve a valid access token that can be used to authorize requests.
async def _get_auth_header(self) -> str:
"""Retrieve the auth header that can be used to authorize requests.
:return: Access token.
:return: Needed content of the `Authorization` header
"""
# TODO: Should we use the environment variable approach instead?
# export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
# credentials, project = google.auth.default(scopes=AUTH_SCOPES)
credentials = service_account.Credentials.from_service_account_file(
str(self.service_account_file),
scopes=AUTH_SCOPES,
)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
return credentials.token
if self.api_version is APIVersion.Legacy:
return "key=%s" % (self.api_key,)
else:
assert self.credentials is not None
await self._refresh_credentials()
return "Bearer %s" % self.credentials.token

async def _refresh_credentials(self) -> None:
assert self.credentials is not None
if not self.credentials.valid:
await Deferred.fromFuture(
asyncio.ensure_future(
self.credentials.refresh(self.google_auth_request)
)
)

async def _dispatch_notification_unlimited(
self, n: Notification, device: Device, context: NotificationContext
Expand Down Expand Up @@ -532,10 +566,7 @@ async def _dispatch_notification_unlimited(
"Content-Type": ["application/json"],
}

if self.api_version == APIVersion.Legacy:
headers["Authorization"] = ["key=%s" % (self.api_key,)]
elif self.api_version is APIVersion.V1:
headers["Authorization"] = ["Bearer %s" % (self._get_access_token(),)]
headers["Authorization"] = [await self._get_auth_header()]

body = self.base_request_body.copy()
body["data"] = data
Expand Down
55 changes: 51 additions & 4 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple
from unittest.mock import MagicMock

from sygnal.gcmpushkin import GcmPushkin
from sygnal.gcmpushkin import APIVersion, GcmPushkin

from tests import testutils
from tests.testutils import DummyResponse
Expand Down Expand Up @@ -79,6 +80,21 @@
}


class TestCredentials:
def __init__(self) -> None:
self.valid = False

@property
def token(self) -> str:
if self.valid:
return "myaccesstoken"
else:
raise Exception()

async def refresh(self, request: Any) -> None:
self.valid = True


class TestGcmPushkin(GcmPushkin):
"""
A GCM pushkin with the ability to make HTTP requests removed and instead
Expand All @@ -92,6 +108,8 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]):
self.last_request_body: Dict[str, Any] = {}
self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type]
self.num_requests = 0
if self.api_version is APIVersion.V1:
self.credentials = TestCredentials() # type: ignore[assignment]

def preload_with_response(
self, code: int, response_payload: Dict[str, Any]
Expand All @@ -110,8 +128,27 @@ async def _perform_http_request( # type: ignore[override]
self.num_requests += 1
return self.preloaded_response, json.dumps(self.preloaded_response_payload)

def _get_access_token(self) -> str:
return "token"
async def _refresh_credentials(self) -> None:
assert self.credentials is not None
if not self.credentials.valid:
await self.credentials.refresh(self.google_auth_request)


FAKE_SERVICE_ACCOUNT_FILE = b"""
{
"type": "service_account",
"project_id": "project_id",
"private_key_id": "private_key_id",
"private_key": "-----BEGIN PRIVATE KEY-----\\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC0PwE6TeTHjD5R\\nY2nOw1rsTgQZ38LCR2CLtx36n+LUkgej/9b+fwC88oKIqJKjUwn43JEOhf4rbA/a\\nqo4jVoLgv754G5+7Glfarr3/rqg+AVT75x6J5DRvhIYpDXwMIUqLAAbfk3TTFNJn\\n2ctrkBF2ZP9p3mzZ3NRjU63Wbf3LBpRqs8jdFEQu8JAecG8VKV1mboJIXG3hwqFN\\nJmcpC/+sWaxB5iMgSqy0w/rGFs6ZbZF6D10XYvf40lEEk9jQIovT+QD4+6GTlroT\\nbOk8uIwxFQcwMFpXj4MktqVNSNyiuuttptIvBWcMWHlaabXrR89vqUFe1g1Jx4GL\\nCF89RrcLAgMBAAECggEAPUYZ3b8zId78JGDeTEq+8wwGeuFFbRQkrvpeN5/41Xib\\nHlZPuQ5lqtXqKBjeWKVXA4G/0icc45gFv7kxPrQfI9YrItuJLmrjKNU0g+HVEdcU\\nE9pa2Fd6t9peXUBXRixfEee9bm3LTiKK8IDqlTNRrGTjKxNQ/7MBhI6izv1vRH/x\\n8i0o1xxNdqstHZ9wBFKYO9w8UQjtfzckkBNDLkaJ/WN0BoRubmUiV1+KwAyyBr6O\\nRnnZ9Tvy8VraSNSdJhX36ai36y18/sT6PWOp99zHYuDyz89KIz1la/fT9eSoR0Jy\\nYePmTEi+9pWhvtpAkqJkRxe5IDz71JVsQ07KoVfzaQKBgQDzKKUd/0ujhv/B9MQf\\nHcwSeWu/XnQ4hlcwz8dTWQjBV8gv9l4yBj9Pra62rg/tQ7b5XKMt6lv/tWs1IpdA\\neMsySY4972VPrmggKXgCnyKckDUYydNtHAIj9buo6AV8rONaneYnGv5wpSsf3q2c\\nOZrkamRgbBkI+B2mZ2obH1oVlQKBgQC9w9HkrDMvZ5L/ilZmpsvoHNFlQwmDgNlN\\n0ej5QGID5rljRM3CcLNHdyQiKqvLA9MCpPEXb2vVJPdmquD12A7a9s0OwxB/dtOD\\nykofcTY0ZHEM1HEyYJGmdK4FvZuNU4o2/D268dePjtj1Xw3c5fs0bcDiGQMtjWlz\\n5hjBzMsyHwKBgGjrIsPcwlBfEcAo0u7yNnnKNnmuUcuJ+9kt7j3Cbwqty80WKvK+\\ny1agBIECfhDMZQkXtbk8JFIjf4y/zi+db1/VaTDEORy2jmtCOWw4KgEQIDj/7OBp\\nc2r8vupUovl2x+rzsrkw5pTIT+FCffqoyHLCjWkle2/pTzHb8Waekoo5AoGAbELk\\nYy5uwTO45Hr60fOEzzZpq/iz28dNshz4agL2KD2gNGcTcEO1tCbfgXKQsfDLmG2b\\ncgBKJ77AOl1wnDEYQIme8TYOGnojL8Pfx9Jh10AaUvR8Y/49+hYFFhdXQCiR6M69\\nNQM2NJuNYWdKVGUMjJu0+AjHDFzp9YonQ6Ffp4cCgYEAmVALALCjU9GjJymgJ0lx\\nD9LccVHMwf9NmR/sMg0XNePRbCEcMDHKdtVJ1zPGS5txuxY3sRb/tDpv7TfuitrU\\nAw0/2ooMzunaoF/HXo+C/+t+pfuqPqLK4sCCyezUlMfCcaPdwXN2FmbgsaFHfe7I\\n7sGEnS/d8wEgydMiptJEf9s=\\n-----END PRIVATE KEY-----\\n",
"client_email": "firebase-adminsdk@project_id.iam.gserviceaccount.com",
"client_id": "client_id",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk%40project_id.iam.gserviceaccount.com",
"universe_domain": "googleapis.com"
}
"""


class GcmTestCase(testutils.TestCase):
Expand All @@ -128,11 +165,14 @@ def config_setup(self, config: Dict[str, Any]) -> None:
"api_key": "kii",
"fcm_options": {"content_available": True, "mutable_content": True},
}
self.service_account_file = tempfile.NamedTemporaryFile()
self.service_account_file.write(FAKE_SERVICE_ACCOUNT_FILE)
self.service_account_file.flush()
config["apps"]["com.example.gcm.apiv1"] = {
"type": "tests.test_gcm.TestGcmPushkin",
"api_version": "v1",
"project_id": "example_project",
"service_account_file": "/path/to/file.json",
"service_account_file": self.service_account_file.name,
"fcm_options": {
"apns": {
"payload": {
Expand All @@ -146,6 +186,9 @@ def config_setup(self, config: Dict[str, Any]) -> None:
},
}

def tearDown(self) -> None:
self.service_account_file.close()

def get_test_pushkin(self, name: str) -> TestGcmPushkin:
pushkin = self.sygnal.pushkins[name]
assert isinstance(pushkin, TestGcmPushkin)
Expand Down Expand Up @@ -260,6 +303,10 @@ def test_expected_api_v1(self) -> None:
)

self.assertEqual(resp, {"rejected": []})
assert notification_req[3] is not None
self.assertEqual(
notification_req[3].get("Authorization"), ["Bearer myaccesstoken"]
)

def test_expected_with_default_payload(self) -> None:
"""
Expand Down

0 comments on commit bb62b17

Please sign in to comment.