From 5f1c1fa8a65256f0b537044a6eb662057ed7e861 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Wed, 4 Oct 2023 19:23:52 +0530 Subject: [PATCH 01/52] Added tokenize_rt as a dev-dependency --- poetry.lock | 145 +++++++++++++++++++------------------------------ pyproject.toml | 1 + 2 files changed, 57 insertions(+), 89 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1510fa0e..a07edf83 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "anyio" version = "3.7.1" description = "High level compatibility layer for multiple asynchronous event loop implementations" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -27,7 +26,6 @@ trio = ["trio (<0.22)"] name = "async-case" version = "10.1.0" description = "Backport of Python 3.8's unittest.async_case" -category = "dev" optional = false python-versions = "*" files = [ @@ -38,7 +36,6 @@ files = [ name = "certifi" version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -50,7 +47,6 @@ files = [ name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -62,7 +58,6 @@ files = [ name = "coverage" version = "7.2.7" description = "Code coverage measurement for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -135,7 +130,6 @@ toml = ["tomli"] name = "exceptiongroup" version = "1.1.3" description = "Backport of PEP 654 (exception groups)" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -150,7 +144,6 @@ test = ["pytest (>=6)"] name = "execnet" version = "2.0.2" description = "execnet: rapid multi-Python deployment" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -165,7 +158,6 @@ testing = ["hatch", "pre-commit", "pytest", "tox"] name = "flake8" version = "3.9.2" description = "the modular source code checker: pep8 pyflakes and co" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -183,7 +175,6 @@ pyflakes = ">=2.3.0,<2.4.0" name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -198,7 +189,6 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""} name = "h2" version = "4.1.0" description = "HTTP/2 State-Machine based protocol implementation" -category = "main" optional = false python-versions = ">=3.6.1" files = [ @@ -214,7 +204,6 @@ hyperframe = ">=6.0,<7" name = "hpack" version = "4.0.0" description = "Pure-Python HPACK header compression" -category = "main" optional = false python-versions = ">=3.6.1" files = [ @@ -226,7 +215,6 @@ files = [ name = "httpcore" version = "0.17.3" description = "A minimal low-level HTTP client." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -238,17 +226,16 @@ files = [ anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = ">=1.0.0,<2.0.0" +sniffio = "==1.*" [package.extras] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (>=1.0.0,<2.0.0)"] +socks = ["socksio (==1.*)"] [[package]] name = "httpx" version = "0.24.1" description = "The next generation HTTP client." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -265,15 +252,14 @@ sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (>=1.0.0,<2.0.0)"] +socks = ["socksio (==1.*)"] [[package]] name = "hyperframe" version = "6.0.1" description = "HTTP/2 framing layer for Python" -category = "main" optional = false python-versions = ">=3.6.1" files = [ @@ -285,7 +271,6 @@ files = [ name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -297,7 +282,6 @@ files = [ name = "importlib-metadata" version = "4.13.0" description = "Read metadata from Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -318,7 +302,6 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -330,7 +313,6 @@ files = [ name = "mccabe" version = "0.6.1" description = "McCabe checker, plugin for flake8" -category = "dev" optional = false python-versions = "*" files = [ @@ -342,7 +324,6 @@ files = [ name = "methoddispatch" version = "3.0.2" description = "singledispatch decorator for class methods." -category = "main" optional = false python-versions = "*" files = [ @@ -354,7 +335,6 @@ files = [ name = "mock" version = "4.0.3" description = "Rolling backport of unittest.mock for all Pythons" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -371,7 +351,6 @@ test = ["pytest (<5.4)", "pytest-cov"] name = "msgpack" version = "1.0.5" description = "MessagePack serializer" -category = "main" optional = false python-versions = "*" files = [ @@ -442,21 +421,19 @@ files = [ [[package]] name = "packaging" -version = "23.1" +version = "23.2" description = "Core utilities for Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, - {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, + {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, + {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] [[package]] name = "pep8-naming" version = "0.4.1" description = "Check PEP-8 naming conventions, plugin for flake8" -category = "dev" optional = false python-versions = "*" files = [ @@ -468,7 +445,6 @@ files = [ name = "pluggy" version = "1.2.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -487,7 +463,6 @@ testing = ["pytest", "pytest-benchmark"] name = "py" version = "1.11.0" description = "library with cross-python path, ini-parsing, io, code, log facilities" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -499,7 +474,6 @@ files = [ name = "pycodestyle" version = "2.7.0" description = "Python style guide checker" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -511,7 +485,6 @@ files = [ name = "pycrypto" version = "2.6.1" description = "Cryptographic modules for Python." -category = "main" optional = true python-versions = "*" files = [ @@ -520,51 +493,49 @@ files = [ [[package]] name = "pycryptodome" -version = "3.18.0" +version = "3.19.0" description = "Cryptographic library for Python" -category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ - {file = "pycryptodome-3.18.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:d1497a8cd4728db0e0da3c304856cb37c0c4e3d0b36fcbabcc1600f18504fc54"}, - {file = "pycryptodome-3.18.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:928078c530da78ff08e10eb6cada6e0dff386bf3d9fa9871b4bbc9fbc1efe024"}, - {file = "pycryptodome-3.18.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:157c9b5ba5e21b375f052ca78152dd309a09ed04703fd3721dce3ff8ecced148"}, - {file = "pycryptodome-3.18.0-cp27-cp27m-manylinux2014_aarch64.whl", hash = "sha256:d20082bdac9218649f6abe0b885927be25a917e29ae0502eaf2b53f1233ce0c2"}, - {file = "pycryptodome-3.18.0-cp27-cp27m-musllinux_1_1_aarch64.whl", hash = "sha256:e8ad74044e5f5d2456c11ed4cfd3e34b8d4898c0cb201c4038fe41458a82ea27"}, - {file = "pycryptodome-3.18.0-cp27-cp27m-win32.whl", hash = "sha256:62a1e8847fabb5213ccde38915563140a5b338f0d0a0d363f996b51e4a6165cf"}, - {file = "pycryptodome-3.18.0-cp27-cp27m-win_amd64.whl", hash = "sha256:16bfd98dbe472c263ed2821284118d899c76968db1a6665ade0c46805e6b29a4"}, - {file = "pycryptodome-3.18.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:7a3d22c8ee63de22336679e021c7f2386f7fc465477d59675caa0e5706387944"}, - {file = "pycryptodome-3.18.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:78d863476e6bad2a592645072cc489bb90320972115d8995bcfbee2f8b209918"}, - {file = "pycryptodome-3.18.0-cp27-cp27mu-manylinux2014_aarch64.whl", hash = "sha256:b6a610f8bfe67eab980d6236fdc73bfcdae23c9ed5548192bb2d530e8a92780e"}, - {file = "pycryptodome-3.18.0-cp27-cp27mu-musllinux_1_1_aarch64.whl", hash = "sha256:422c89fd8df8a3bee09fb8d52aaa1e996120eafa565437392b781abec2a56e14"}, - {file = "pycryptodome-3.18.0-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:9ad6f09f670c466aac94a40798e0e8d1ef2aa04589c29faa5b9b97566611d1d1"}, - {file = "pycryptodome-3.18.0-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:53aee6be8b9b6da25ccd9028caf17dcdce3604f2c7862f5167777b707fbfb6cb"}, - {file = "pycryptodome-3.18.0-cp35-abi3-manylinux2014_aarch64.whl", hash = "sha256:10da29526a2a927c7d64b8f34592f461d92ae55fc97981aab5bbcde8cb465bb6"}, - {file = "pycryptodome-3.18.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f21efb8438971aa16924790e1c3dba3a33164eb4000106a55baaed522c261acf"}, - {file = "pycryptodome-3.18.0-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4944defabe2ace4803f99543445c27dd1edbe86d7d4edb87b256476a91e9ffa4"}, - {file = "pycryptodome-3.18.0-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:51eae079ddb9c5f10376b4131be9589a6554f6fd84f7f655180937f611cd99a2"}, - {file = "pycryptodome-3.18.0-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:83c75952dcf4a4cebaa850fa257d7a860644c70a7cd54262c237c9f2be26f76e"}, - {file = "pycryptodome-3.18.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:957b221d062d5752716923d14e0926f47670e95fead9d240fa4d4862214b9b2f"}, - {file = "pycryptodome-3.18.0-cp35-abi3-win32.whl", hash = "sha256:795bd1e4258a2c689c0b1f13ce9684fa0dd4c0e08680dcf597cf9516ed6bc0f3"}, - {file = "pycryptodome-3.18.0-cp35-abi3-win_amd64.whl", hash = "sha256:b1d9701d10303eec8d0bd33fa54d44e67b8be74ab449052a8372f12a66f93fb9"}, - {file = "pycryptodome-3.18.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:cb1be4d5af7f355e7d41d36d8eec156ef1382a88638e8032215c215b82a4b8ec"}, - {file = "pycryptodome-3.18.0-pp27-pypy_73-win32.whl", hash = "sha256:fc0a73f4db1e31d4a6d71b672a48f3af458f548059aa05e83022d5f61aac9c08"}, - {file = "pycryptodome-3.18.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f022a4fd2a5263a5c483a2bb165f9cb27f2be06f2f477113783efe3fe2ad887b"}, - {file = "pycryptodome-3.18.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:363dd6f21f848301c2dcdeb3c8ae5f0dee2286a5e952a0f04954b82076f23825"}, - {file = "pycryptodome-3.18.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12600268763e6fec3cefe4c2dcdf79bde08d0b6dc1813887e789e495cb9f3403"}, - {file = "pycryptodome-3.18.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4604816adebd4faf8810782f137f8426bf45fee97d8427fa8e1e49ea78a52e2c"}, - {file = "pycryptodome-3.18.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:01489bbdf709d993f3058e2996f8f40fee3f0ea4d995002e5968965fa2fe89fb"}, - {file = "pycryptodome-3.18.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3811e31e1ac3069988f7a1c9ee7331b942e605dfc0f27330a9ea5997e965efb2"}, - {file = "pycryptodome-3.18.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f4b967bb11baea9128ec88c3d02f55a3e338361f5e4934f5240afcb667fdaec"}, - {file = "pycryptodome-3.18.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9c8eda4f260072f7dbe42f473906c659dcbadd5ae6159dfb49af4da1293ae380"}, - {file = "pycryptodome-3.18.0.tar.gz", hash = "sha256:c9adee653fc882d98956e33ca2c1fb582e23a8af7ac82fee75bd6113c55a0413"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:3006c44c4946583b6de24fe0632091c2653d6256b99a02a3db71ca06472ea1e4"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:7c760c8a0479a4042111a8dd2f067d3ae4573da286c53f13cf6f5c53a5c1f631"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:08ce3558af5106c632baf6d331d261f02367a6bc3733086ae43c0f988fe042db"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45430dfaf1f421cf462c0dd824984378bef32b22669f2635cb809357dbaab405"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-musllinux_1_1_aarch64.whl", hash = "sha256:a9bcd5f3794879e91970f2bbd7d899780541d3ff439d8f2112441769c9f2ccea"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-win32.whl", hash = "sha256:190c53f51e988dceb60472baddce3f289fa52b0ec38fbe5fd20dd1d0f795c551"}, + {file = "pycryptodome-3.19.0-cp27-cp27m-win_amd64.whl", hash = "sha256:22e0ae7c3a7f87dcdcf302db06ab76f20e83f09a6993c160b248d58274473bfa"}, + {file = "pycryptodome-3.19.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:7822f36d683f9ad7bc2145b2c2045014afdbbd1d9922a6d4ce1cbd6add79a01e"}, + {file = "pycryptodome-3.19.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:05e33267394aad6db6595c0ce9d427fe21552f5425e116a925455e099fdf759a"}, + {file = "pycryptodome-3.19.0-cp27-cp27mu-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:829b813b8ee00d9c8aba417621b94bc0b5efd18c928923802ad5ba4cf1ec709c"}, + {file = "pycryptodome-3.19.0-cp27-cp27mu-musllinux_1_1_aarch64.whl", hash = "sha256:fc7a79590e2b5d08530175823a242de6790abc73638cc6dc9d2684e7be2f5e49"}, + {file = "pycryptodome-3.19.0-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:542f99d5026ac5f0ef391ba0602f3d11beef8e65aae135fa5b762f5ebd9d3bfb"}, + {file = "pycryptodome-3.19.0-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:61bb3ccbf4bf32ad9af32da8badc24e888ae5231c617947e0f5401077f8b091f"}, + {file = "pycryptodome-3.19.0-cp35-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d49a6c715d8cceffedabb6adb7e0cbf41ae1a2ff4adaeec9432074a80627dea1"}, + {file = "pycryptodome-3.19.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e249a784cc98a29c77cea9df54284a44b40cafbfae57636dd2f8775b48af2434"}, + {file = "pycryptodome-3.19.0-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d033947e7fd3e2ba9a031cb2d267251620964705a013c5a461fa5233cc025270"}, + {file = "pycryptodome-3.19.0-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:84c3e4fffad0c4988aef0d5591be3cad4e10aa7db264c65fadbc633318d20bde"}, + {file = "pycryptodome-3.19.0-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:139ae2c6161b9dd5d829c9645d781509a810ef50ea8b657e2257c25ca20efe33"}, + {file = "pycryptodome-3.19.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:5b1986c761258a5b4332a7f94a83f631c1ffca8747d75ab8395bf2e1b93283d9"}, + {file = "pycryptodome-3.19.0-cp35-abi3-win32.whl", hash = "sha256:536f676963662603f1f2e6ab01080c54d8cd20f34ec333dcb195306fa7826997"}, + {file = "pycryptodome-3.19.0-cp35-abi3-win_amd64.whl", hash = "sha256:04dd31d3b33a6b22ac4d432b3274588917dcf850cc0c51c84eca1d8ed6933810"}, + {file = "pycryptodome-3.19.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:8999316e57abcbd8085c91bc0ef75292c8618f41ca6d2b6132250a863a77d1e7"}, + {file = "pycryptodome-3.19.0-pp27-pypy_73-win32.whl", hash = "sha256:a0ab84755f4539db086db9ba9e9f3868d2e3610a3948cbd2a55e332ad83b01b0"}, + {file = "pycryptodome-3.19.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0101f647d11a1aae5a8ce4f5fad6644ae1b22bb65d05accc7d322943c69a74a6"}, + {file = "pycryptodome-3.19.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c1601e04d32087591d78e0b81e1e520e57a92796089864b20e5f18c9564b3fa"}, + {file = "pycryptodome-3.19.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:506c686a1eee6c00df70010be3b8e9e78f406af4f21b23162bbb6e9bdf5427bc"}, + {file = "pycryptodome-3.19.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7919ccd096584b911f2a303c593280869ce1af9bf5d36214511f5e5a1bed8c34"}, + {file = "pycryptodome-3.19.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:560591c0777f74a5da86718f70dfc8d781734cf559773b64072bbdda44b3fc3e"}, + {file = "pycryptodome-3.19.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1cc2f2ae451a676def1a73c1ae9120cd31af25db3f381893d45f75e77be2400"}, + {file = "pycryptodome-3.19.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17940dcf274fcae4a54ec6117a9ecfe52907ed5e2e438fe712fe7ca502672ed5"}, + {file = "pycryptodome-3.19.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d04f5f623a280fbd0ab1c1d8ecbd753193ab7154f09b6161b0f857a1a676c15f"}, + {file = "pycryptodome-3.19.0.tar.gz", hash = "sha256:bc35d463222cdb4dbebd35e0784155c81e161b9284e567e7e933d722e533331e"}, ] [[package]] name = "pyee" version = "9.1.1" description = "A port of node.js's EventEmitter to python." -category = "main" optional = false python-versions = "*" files = [ @@ -579,7 +550,6 @@ typing-extensions = "*" name = "pyflakes" version = "2.3.1" description = "passive checker of Python programs" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -589,14 +559,13 @@ files = [ [[package]] name = "pytest" -version = "7.4.0" +version = "7.4.2" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, - {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, + {file = "pytest-7.4.2-py3-none-any.whl", hash = "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002"}, + {file = "pytest-7.4.2.tar.gz", hash = "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069"}, ] [package.dependencies] @@ -615,7 +584,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "2.12.1" description = "Pytest plugin for measuring coverage." -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -635,7 +603,6 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "pytest-flake8" version = "1.1.0" description = "pytest plugin to check FLAKE8 requirements" -category = "dev" optional = false python-versions = "*" files = [ @@ -651,7 +618,6 @@ pytest = ">=3.5" name = "pytest-forked" version = "1.6.0" description = "run tests in isolated forked subprocesses" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -667,7 +633,6 @@ pytest = ">=3.10" name = "pytest-timeout" version = "2.1.0" description = "pytest plugin to abort hanging tests" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -682,7 +647,6 @@ pytest = ">=5.0.0" name = "pytest-xdist" version = "1.34.0" description = "pytest xdist plugin for distributed testing and loop-on-failing modes" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -703,7 +667,6 @@ testing = ["filelock"] name = "respx" version = "0.20.2" description = "A utility for mocking out the Python HTTPX and HTTP Core libraries." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -718,7 +681,6 @@ httpx = ">=0.21.0" name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -730,7 +692,6 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -738,11 +699,21 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +[[package]] +name = "tokenize-rt" +version = "5.0.0" +description = "A wrapper around the stdlib `tokenize` which roundtrips." +optional = false +python-versions = ">=3.7" +files = [ + {file = "tokenize_rt-5.0.0-py2.py3-none-any.whl", hash = "sha256:c67772c662c6b3dc65edf66808577968fb10badfc2042e3027196bed4daf9e5a"}, + {file = "tokenize_rt-5.0.0.tar.gz", hash = "sha256:3160bc0c3e8491312d0485171dea861fc160a240f5f5766b72a1165408d10740"}, +] + [[package]] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" -category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -754,7 +725,6 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -766,7 +736,6 @@ files = [ name = "typing-extensions" version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -778,7 +747,6 @@ files = [ name = "websockets" version = "10.4" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -857,7 +825,6 @@ files = [ name = "zipp" version = "3.15.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -876,4 +843,4 @@ oldcrypto = ["pycrypto"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "885ad9d7e6a0adc96cae0dcf69a7c8d7af8dbbf3651b7cce29deed789ad581e7" +content-hash = "a6ee4818d5e151e0149c60bb77a2c74aa9f8e676ffd99277af588ad06031c67d" diff --git a/pyproject.toml b/pyproject.toml index 3cb26fb9..1e0a1e78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ respx = "^0.20.0" importlib-metadata = "^4.12" pytest-timeout = "^2.1.0" async-case = { version = "^10.1.0", python = "~3.7" } +tokenize_rt = "*" [build-system] requires = ["poetry-core>=1.0.0"] From ef8c7a70b4b340cd6a21cd62a9a02bc27ffee6c4 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Wed, 4 Oct 2023 20:52:40 +0530 Subject: [PATCH 02/52] Created unasync file to convert async code to sync code --- unasync.py | 199 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 unasync.py diff --git a/unasync.py b/unasync.py new file mode 100644 index 00000000..cf4ac648 --- /dev/null +++ b/unasync.py @@ -0,0 +1,199 @@ +"""Top-level package for unasync.""" + +import collections +import glob +import os +import tokenize as std_tokenize + +import tokenize_rt + +_ASYNC_TO_SYNC = { + "__aenter__": "__enter__", + "__aexit__": "__exit__", + "__aiter__": "__iter__", + "__anext__": "__next__", + "asynccontextmanager": "contextmanager", + "AsyncIterable": "Iterable", + "AsyncIterator": "Iterator", + "AsyncGenerator": "Generator", + # TODO StopIteration is still accepted in Python 2, but the right change + # is 'raise StopAsyncIteration' -> 'return' since we want to use unasynced + # code in Python 3.7+ + "StopAsyncIteration": "StopIteration", +} + + +class Rule: + """A single set of rules for 'unasync'ing file(s)""" + + def __init__(self, fromdir, todir, additional_replacements=None): + self.fromdir = fromdir.replace("/", os.sep) + self.todir = todir.replace("/", os.sep) + + # Add any additional user-defined token replacements to our list. + self.token_replacements = _ASYNC_TO_SYNC.copy() + for key, val in (additional_replacements or {}).items(): + self.token_replacements[key] = val + + def _match(self, filepath): + """Determines if a Rule matches a given filepath and if so + returns a higher comparable value if the match is more specific. + """ + file_segments = [x for x in filepath.split(os.sep) if x] + from_segments = [x for x in self.fromdir.split(os.sep) if x] + len_from_segments = len(from_segments) + + if len_from_segments > len(file_segments): + return False + + for i in range(len(file_segments) - len_from_segments + 1): + if file_segments[i: i + len_from_segments] == from_segments: + return len_from_segments, i + + return False + + def _unasync_file(self, filepath): + with open(filepath, "rb") as f: + encoding, _ = std_tokenize.detect_encoding(f.readline) + + with open(filepath, "rt", encoding=encoding) as f: + tokens = tokenize_rt.src_to_tokens(f.read()) + tokens = self._unasync_tokens(tokens) + result = tokenize_rt.tokens_to_src(tokens) + outfilepath = filepath.replace(self.fromdir, self.todir) + os.makedirs(os.path.dirname(outfilepath), exist_ok=True) + with open(outfilepath, "wb") as f: + f.write(result.encode(encoding)) + + def _unasync_tokens(self, tokens): + skip_next = False + for i, token in enumerate(tokens): + if skip_next: + skip_next = False + continue + + if token.src in ["async", "await"]: + # When removing async or await, we want to skip the following whitespace + # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` + skip_next = True + else: + if token.name == "NAME": + token = token._replace(src=self._unasync_name(token.src)) + elif token.name == "STRING": + left_quote, name, right_quote = ( + token.src[0], + token.src[1:-1], + token.src[-1], + ) + token = token._replace( + src=left_quote + self._unasync_name(name) + right_quote + ) + + yield token + + def _unasync_name(self, name): + if name in self.token_replacements: + return self.token_replacements[name] + # Convert classes prefixed with 'Async' into 'Sync' + elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): + return "Sync" + name[5:] + return name + + +def unasync_files(fpath_list, rules): + for f in fpath_list: + found_rule = None + found_weight = None + + for rule in rules: + weight = rule._match(f) + if weight and (found_weight is None or weight > found_weight): + found_rule = rule + found_weight = weight + + if found_rule: + found_rule._unasync_file(f) + + +Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) + +_ASYNC_TO_SYNC["http"] = "ably.sync.http.paginatedresult" + +src_dir_path = os.path.join(os.getcwd(), "ably", "rest") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "rest") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + +os.makedirs(dest_dir_path, exist_ok=True) + +def find_files(dir_path, file_name_regex) -> list[str]: + return glob.glob(os.path.join(dir_path, "*" + file_name_regex)) + + +src_files = find_files(src_dir_path, ".py") + +unasync_files(src_files, (_DEFAULT_RULE,)) + +# round 2 +src_dir_path = os.path.join(os.getcwd(), "ably", "http") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "http") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + + +src_files = find_files(src_dir_path, ".py") + +unasync_files(src_files, (_DEFAULT_RULE,)) + +# round 3 + +src_dir_path = os.path.join(os.getcwd(), "ably", "types") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "types") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + + +src_files = find_files(src_dir_path, "presence.py") + +unasync_files(src_files, (_DEFAULT_RULE,)) + + +# class _build_py(orig.build_py): +# """ +# Subclass build_py from setuptools to modify its behavior. +# +# Convert files in _async dir from being asynchronous to synchronous +# and saves them in _sync dir. +# """ +# +# UNASYNC_RULES = (_DEFAULT_RULE,) +# +# def run(self): +# rules = self.UNASYNC_RULES +# +# self._updated_files = [] +# +# # Base class code +# if self.py_modules: +# self.build_modules() +# if self.packages: +# self.build_packages() +# self.build_package_data() +# +# # Our modification! +# unasync_files(self._updated_files, rules) +# +# # Remaining base class code +# self.byte_compile(self.get_outputs(include_bytecode=0)) +# +# def build_module(self, module, module_file, package): +# outfile, copied = super().build_module(module, module_file, package) +# if copied: +# self._updated_files.append(outfile) +# return outfile, copied +# +# +# def cmdclass_build_py(rules=(_DEFAULT_RULE,)): +# """Creates a 'build_py' class for use within 'cmdclass={"build_py": ...}'""" +# +# class _custom_build_py(_build_py): +# UNASYNC_RULES = rules +# +# return _custom_build_py From 86f55308ca688ac8eae04eb93f193375391d6c18 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Wed, 4 Oct 2023 21:22:17 +0530 Subject: [PATCH 03/52] Added counter based tokenization strategy for replacing tokens --- unasync.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/unasync.py b/unasync.py index cf4ac648..71aeecc6 100644 --- a/unasync.py +++ b/unasync.py @@ -65,17 +65,16 @@ def _unasync_file(self, filepath): with open(outfilepath, "wb") as f: f.write(result.encode(encoding)) - def _unasync_tokens(self, tokens): - skip_next = False - for i, token in enumerate(tokens): - if skip_next: - skip_next = False - continue + def _unasync_tokens(self, tokens: list): + new_tokens = [] + token_counter = 0 + while token_counter < len(tokens): + token = tokens[token_counter] if token.src in ["async", "await"]: # When removing async or await, we want to skip the following whitespace # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` - skip_next = True + token_counter = token_counter + 1 else: if token.name == "NAME": token = token._replace(src=self._unasync_name(token.src)) @@ -89,7 +88,34 @@ def _unasync_tokens(self, tokens): src=left_quote + self._unasync_name(name) + right_quote ) - yield token + new_tokens.append(token) + token_counter = token_counter + 1 + + return new_tokens + + # for i, token in enumerate(tokens): + # if skip_next: + # skip_next = False + # continue + # + # if token.src in ["async", "await"]: + # # When removing async or await, we want to skip the following whitespace + # # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` + # skip_next = True + # else: + # if token.name == "NAME": + # token = token._replace(src=self._unasync_name(token.src)) + # elif token.name == "STRING": + # left_quote, name, right_quote = ( + # token.src[0], + # token.src[1:-1], + # token.src[-1], + # ) + # token = token._replace( + # src=left_quote + self._unasync_name(name) + right_quote + # ) + # + # yield token def _unasync_name(self, name): if name in self.token_replacements: From 9626ef987598fb63d748a78d6f91eded96e92111 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Wed, 4 Oct 2023 22:34:40 +0530 Subject: [PATCH 04/52] Added code to replace imports using tokenizer --- unasync.py | 68 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/unasync.py b/unasync.py index 71aeecc6..f4461a38 100644 --- a/unasync.py +++ b/unasync.py @@ -22,6 +22,10 @@ "StopAsyncIteration": "StopIteration", } +_IMPORTS_REPLACE = { + +} + class Rule: """A single set of rules for 'unasync'ing file(s)""" @@ -72,23 +76,26 @@ def _unasync_tokens(self, tokens: list): token = tokens[token_counter] if token.src in ["async", "await"]: - # When removing async or await, we want to skip the following whitespace - # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` - token_counter = token_counter + 1 - else: - if token.name == "NAME": - token = token._replace(src=self._unasync_name(token.src)) - elif token.name == "STRING": - left_quote, name, right_quote = ( - token.src[0], - token.src[1:-1], - token.src[-1], - ) - token = token._replace( - src=left_quote + self._unasync_name(name) + right_quote - ) - - new_tokens.append(token) + token_counter = token_counter + 1 # When removing async or await, we want to skip the following whitespace + continue + elif token.name == "NAME": + if token.src == "from": + if tokens[token_counter + 1].src == " ": + token_counter = self._replace_import(tokens, token_counter, new_tokens) + continue + else: + token = token._replace(src=self._unasync_name(token.src)) + elif token.name == "STRING": + left_quote, name, right_quote = ( + token.src[0], + token.src[1:-1], + token.src[-1], + ) + token = token._replace( + src=left_quote + self._unasync_name(name) + right_quote + ) + + new_tokens.append(token) token_counter = token_counter + 1 return new_tokens @@ -117,6 +124,26 @@ def _unasync_tokens(self, tokens: list): # # yield token + def _replace_import(self, tokens, token_counter, new_tokens: list): + new_tokens.append(tokens[token_counter]) + new_tokens.append(tokens[token_counter + 1]) + + full_lib_name = '' + lib_name_counter = token_counter + 2 + while True: + if tokens[lib_name_counter].src == " ": + break + full_lib_name = full_lib_name + tokens[lib_name_counter].src + lib_name_counter = lib_name_counter + 1 + + if full_lib_name in _IMPORTS_REPLACE: + for lib_name_token in _IMPORTS_REPLACE[full_lib_name].split("."): + new_tokens.append(tokenize_rt.Token("NAME", lib_name_token)) + new_tokens.append(tokenize_rt.Token("OP", ".")) + new_tokens.pop() + + return lib_name_counter + def _unasync_name(self, name): if name in self.token_replacements: return self.token_replacements[name] @@ -141,16 +168,16 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) +_IMPORTS_REPLACE["ably.http.paginatedresult"] = "ably.nako.paginatedresult" Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) -_ASYNC_TO_SYNC["http"] = "ably.sync.http.paginatedresult" - src_dir_path = os.path.join(os.getcwd(), "ably", "rest") dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "rest") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) os.makedirs(dest_dir_path, exist_ok=True) + def find_files(dir_path, file_name_regex) -> list[str]: return glob.glob(os.path.join(dir_path, "*" + file_name_regex)) @@ -164,7 +191,6 @@ def find_files(dir_path, file_name_regex) -> list[str]: dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "http") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - src_files = find_files(src_dir_path, ".py") unasync_files(src_files, (_DEFAULT_RULE,)) @@ -175,12 +201,10 @@ def find_files(dir_path, file_name_regex) -> list[str]: dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "types") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - src_files = find_files(src_dir_path, "presence.py") unasync_files(src_files, (_DEFAULT_RULE,)) - # class _build_py(orig.build_py): # """ # Subclass build_py from setuptools to modify its behavior. From 1b5b4e04bb45710eb7ebdbe3bd63810fcd9943c2 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Wed, 4 Oct 2023 23:42:51 +0530 Subject: [PATCH 05/52] Updated code for replacing imports --- unasync.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/unasync.py b/unasync.py index f4461a38..c2418301 100644 --- a/unasync.py +++ b/unasync.py @@ -136,12 +136,16 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): full_lib_name = full_lib_name + tokens[lib_name_counter].src lib_name_counter = lib_name_counter + 1 - if full_lib_name in _IMPORTS_REPLACE: - for lib_name_token in _IMPORTS_REPLACE[full_lib_name].split("."): - new_tokens.append(tokenize_rt.Token("NAME", lib_name_token)) - new_tokens.append(tokenize_rt.Token("OP", ".")) - new_tokens.pop() - + for key, value in _IMPORTS_REPLACE.items(): + if key in full_lib_name: + updated_lib_name = full_lib_name.replace(key, value) + for lib_name_part in updated_lib_name.split("."): + new_tokens.append(tokenize_rt.Token("NAME", lib_name_part)) + new_tokens.append(tokenize_rt.Token("OP", ".")) + if full_lib_name == key: + new_tokens.pop() + else: + lib_name_counter = token_counter + 2 return lib_name_counter def _unasync_name(self, name): @@ -168,7 +172,7 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) -_IMPORTS_REPLACE["ably.http.paginatedresult"] = "ably.nako.paginatedresult" +_IMPORTS_REPLACE["ably.http.paginatedresult"] = "ably.dong.paginatedresult" Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) src_dir_path = os.path.join(os.getcwd(), "ably", "rest") From 9af0ffc6bbb675974cab87e63f9866c0566b2d1e Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 12:25:02 +0530 Subject: [PATCH 06/52] Refactored unasync file for fixing imports --- unasync.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/unasync.py b/unasync.py index c2418301..1b6fa3ef 100644 --- a/unasync.py +++ b/unasync.py @@ -20,6 +20,8 @@ # is 'raise StopAsyncIteration' -> 'return' since we want to use unasynced # code in Python 3.7+ "StopAsyncIteration": "StopIteration", + "AsyncClient": "Client", + "aclose": "close" } _IMPORTS_REPLACE = { @@ -76,15 +78,15 @@ def _unasync_tokens(self, tokens: list): token = tokens[token_counter] if token.src in ["async", "await"]: - token_counter = token_counter + 1 # When removing async or await, we want to skip the following whitespace + token_counter = token_counter + 2 # When removing async or await, we want to skip the following whitespace continue elif token.name == "NAME": if token.src == "from": if tokens[token_counter + 1].src == " ": token_counter = self._replace_import(tokens, token_counter, new_tokens) continue - else: - token = token._replace(src=self._unasync_name(token.src)) + else: + token = token._replace(src=self._unasync_name(token.src)) elif token.name == "STRING": left_quote, name, right_quote = ( token.src[0], @@ -130,6 +132,9 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): full_lib_name = '' lib_name_counter = token_counter + 2 + if len(_IMPORTS_REPLACE.keys()) == 0: + return lib_name_counter + while True: if tokens[lib_name_counter].src == " ": break @@ -142,18 +147,18 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): for lib_name_part in updated_lib_name.split("."): new_tokens.append(tokenize_rt.Token("NAME", lib_name_part)) new_tokens.append(tokenize_rt.Token("OP", ".")) - if full_lib_name == key: - new_tokens.pop() - else: - lib_name_counter = token_counter + 2 + new_tokens.pop() + return lib_name_counter + + lib_name_counter = token_counter + 2 return lib_name_counter def _unasync_name(self, name): if name in self.token_replacements: return self.token_replacements[name] # Convert classes prefixed with 'Async' into 'Sync' - elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): - return "Sync" + name[5:] + # elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): + # return "Sync" + name[5:] return name @@ -172,7 +177,10 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) -_IMPORTS_REPLACE["ably.http.paginatedresult"] = "ably.dong.paginatedresult" +_IMPORTS_REPLACE["ably.http"] = "ably.sync.http" +_IMPORTS_REPLACE["ably.rest"] = "ably.sync.rest" +# _IMPORTS_REPLACE["ably.types"] = "ably.types.sync" + Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) src_dir_path = os.path.join(os.getcwd(), "ably", "rest") @@ -183,10 +191,10 @@ def unasync_files(fpath_list, rules): def find_files(dir_path, file_name_regex) -> list[str]: - return glob.glob(os.path.join(dir_path, "*" + file_name_regex)) + return glob.glob(os.path.join(dir_path, file_name_regex)) -src_files = find_files(src_dir_path, ".py") +src_files = find_files(src_dir_path, "*.py") unasync_files(src_files, (_DEFAULT_RULE,)) @@ -195,7 +203,7 @@ def find_files(dir_path, file_name_regex) -> list[str]: dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "http") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) -src_files = find_files(src_dir_path, ".py") +src_files = find_files(src_dir_path, "*.py") unasync_files(src_files, (_DEFAULT_RULE,)) From 30bf9c382bec0b36c75ee402a60b0a6bb00d9640 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 13:12:04 +0530 Subject: [PATCH 07/52] Added unasync_test file for generating tests --- unasync_test.py | 261 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 unasync_test.py diff --git a/unasync_test.py b/unasync_test.py new file mode 100644 index 00000000..1b6fa3ef --- /dev/null +++ b/unasync_test.py @@ -0,0 +1,261 @@ +"""Top-level package for unasync.""" + +import collections +import glob +import os +import tokenize as std_tokenize + +import tokenize_rt + +_ASYNC_TO_SYNC = { + "__aenter__": "__enter__", + "__aexit__": "__exit__", + "__aiter__": "__iter__", + "__anext__": "__next__", + "asynccontextmanager": "contextmanager", + "AsyncIterable": "Iterable", + "AsyncIterator": "Iterator", + "AsyncGenerator": "Generator", + # TODO StopIteration is still accepted in Python 2, but the right change + # is 'raise StopAsyncIteration' -> 'return' since we want to use unasynced + # code in Python 3.7+ + "StopAsyncIteration": "StopIteration", + "AsyncClient": "Client", + "aclose": "close" +} + +_IMPORTS_REPLACE = { + +} + + +class Rule: + """A single set of rules for 'unasync'ing file(s)""" + + def __init__(self, fromdir, todir, additional_replacements=None): + self.fromdir = fromdir.replace("/", os.sep) + self.todir = todir.replace("/", os.sep) + + # Add any additional user-defined token replacements to our list. + self.token_replacements = _ASYNC_TO_SYNC.copy() + for key, val in (additional_replacements or {}).items(): + self.token_replacements[key] = val + + def _match(self, filepath): + """Determines if a Rule matches a given filepath and if so + returns a higher comparable value if the match is more specific. + """ + file_segments = [x for x in filepath.split(os.sep) if x] + from_segments = [x for x in self.fromdir.split(os.sep) if x] + len_from_segments = len(from_segments) + + if len_from_segments > len(file_segments): + return False + + for i in range(len(file_segments) - len_from_segments + 1): + if file_segments[i: i + len_from_segments] == from_segments: + return len_from_segments, i + + return False + + def _unasync_file(self, filepath): + with open(filepath, "rb") as f: + encoding, _ = std_tokenize.detect_encoding(f.readline) + + with open(filepath, "rt", encoding=encoding) as f: + tokens = tokenize_rt.src_to_tokens(f.read()) + tokens = self._unasync_tokens(tokens) + result = tokenize_rt.tokens_to_src(tokens) + outfilepath = filepath.replace(self.fromdir, self.todir) + os.makedirs(os.path.dirname(outfilepath), exist_ok=True) + with open(outfilepath, "wb") as f: + f.write(result.encode(encoding)) + + def _unasync_tokens(self, tokens: list): + new_tokens = [] + token_counter = 0 + while token_counter < len(tokens): + token = tokens[token_counter] + + if token.src in ["async", "await"]: + token_counter = token_counter + 2 # When removing async or await, we want to skip the following whitespace + continue + elif token.name == "NAME": + if token.src == "from": + if tokens[token_counter + 1].src == " ": + token_counter = self._replace_import(tokens, token_counter, new_tokens) + continue + else: + token = token._replace(src=self._unasync_name(token.src)) + elif token.name == "STRING": + left_quote, name, right_quote = ( + token.src[0], + token.src[1:-1], + token.src[-1], + ) + token = token._replace( + src=left_quote + self._unasync_name(name) + right_quote + ) + + new_tokens.append(token) + token_counter = token_counter + 1 + + return new_tokens + + # for i, token in enumerate(tokens): + # if skip_next: + # skip_next = False + # continue + # + # if token.src in ["async", "await"]: + # # When removing async or await, we want to skip the following whitespace + # # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` + # skip_next = True + # else: + # if token.name == "NAME": + # token = token._replace(src=self._unasync_name(token.src)) + # elif token.name == "STRING": + # left_quote, name, right_quote = ( + # token.src[0], + # token.src[1:-1], + # token.src[-1], + # ) + # token = token._replace( + # src=left_quote + self._unasync_name(name) + right_quote + # ) + # + # yield token + + def _replace_import(self, tokens, token_counter, new_tokens: list): + new_tokens.append(tokens[token_counter]) + new_tokens.append(tokens[token_counter + 1]) + + full_lib_name = '' + lib_name_counter = token_counter + 2 + if len(_IMPORTS_REPLACE.keys()) == 0: + return lib_name_counter + + while True: + if tokens[lib_name_counter].src == " ": + break + full_lib_name = full_lib_name + tokens[lib_name_counter].src + lib_name_counter = lib_name_counter + 1 + + for key, value in _IMPORTS_REPLACE.items(): + if key in full_lib_name: + updated_lib_name = full_lib_name.replace(key, value) + for lib_name_part in updated_lib_name.split("."): + new_tokens.append(tokenize_rt.Token("NAME", lib_name_part)) + new_tokens.append(tokenize_rt.Token("OP", ".")) + new_tokens.pop() + return lib_name_counter + + lib_name_counter = token_counter + 2 + return lib_name_counter + + def _unasync_name(self, name): + if name in self.token_replacements: + return self.token_replacements[name] + # Convert classes prefixed with 'Async' into 'Sync' + # elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): + # return "Sync" + name[5:] + return name + + +def unasync_files(fpath_list, rules): + for f in fpath_list: + found_rule = None + found_weight = None + + for rule in rules: + weight = rule._match(f) + if weight and (found_weight is None or weight > found_weight): + found_rule = rule + found_weight = weight + + if found_rule: + found_rule._unasync_file(f) + + +_IMPORTS_REPLACE["ably.http"] = "ably.sync.http" +_IMPORTS_REPLACE["ably.rest"] = "ably.sync.rest" +# _IMPORTS_REPLACE["ably.types"] = "ably.types.sync" + +Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) + +src_dir_path = os.path.join(os.getcwd(), "ably", "rest") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "rest") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + +os.makedirs(dest_dir_path, exist_ok=True) + + +def find_files(dir_path, file_name_regex) -> list[str]: + return glob.glob(os.path.join(dir_path, file_name_regex)) + + +src_files = find_files(src_dir_path, "*.py") + +unasync_files(src_files, (_DEFAULT_RULE,)) + +# round 2 +src_dir_path = os.path.join(os.getcwd(), "ably", "http") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "http") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + +src_files = find_files(src_dir_path, "*.py") + +unasync_files(src_files, (_DEFAULT_RULE,)) + +# round 3 + +src_dir_path = os.path.join(os.getcwd(), "ably", "types") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "types") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + +src_files = find_files(src_dir_path, "presence.py") + +unasync_files(src_files, (_DEFAULT_RULE,)) + +# class _build_py(orig.build_py): +# """ +# Subclass build_py from setuptools to modify its behavior. +# +# Convert files in _async dir from being asynchronous to synchronous +# and saves them in _sync dir. +# """ +# +# UNASYNC_RULES = (_DEFAULT_RULE,) +# +# def run(self): +# rules = self.UNASYNC_RULES +# +# self._updated_files = [] +# +# # Base class code +# if self.py_modules: +# self.build_modules() +# if self.packages: +# self.build_packages() +# self.build_package_data() +# +# # Our modification! +# unasync_files(self._updated_files, rules) +# +# # Remaining base class code +# self.byte_compile(self.get_outputs(include_bytecode=0)) +# +# def build_module(self, module, module_file, package): +# outfile, copied = super().build_module(module, module_file, package) +# if copied: +# self._updated_files.append(outfile) +# return outfile, copied +# +# +# def cmdclass_build_py(rules=(_DEFAULT_RULE,)): +# """Creates a 'build_py' class for use within 'cmdclass={"build_py": ...}'""" +# +# class _custom_build_py(_build_py): +# UNASYNC_RULES = rules +# +# return _custom_build_py From 9cad493a8485b8e1c7afa7de7a57ee95c5df6323 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 13:12:25 +0530 Subject: [PATCH 08/52] Updated unasync test file for generating rest only tests --- unasync_test.py | 47 ++++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/unasync_test.py b/unasync_test.py index 1b6fa3ef..96b7c721 100644 --- a/unasync_test.py +++ b/unasync_test.py @@ -21,13 +21,19 @@ # code in Python 3.7+ "StopAsyncIteration": "StopIteration", "AsyncClient": "Client", - "aclose": "close" + "aclose": "close", + "asyncSetUp": "setUp", + "asyncTearDown": "tearDown" } _IMPORTS_REPLACE = { } +_STRING_REPLACE = { + '/../assets/testAppSpec.json': '/../../assets/testAppSpec.json' +} + class Rule: """A single set of rules for 'unasync'ing file(s)""" @@ -76,7 +82,8 @@ def _unasync_tokens(self, tokens: list): token_counter = 0 while token_counter < len(tokens): token = tokens[token_counter] - + if token.src == "'/../assets/testAppSpec.json'": + print("hi") if token.src in ["async", "await"]: token_counter = token_counter + 2 # When removing async or await, we want to skip the following whitespace continue @@ -88,14 +95,10 @@ def _unasync_tokens(self, tokens: list): else: token = token._replace(src=self._unasync_name(token.src)) elif token.name == "STRING": - left_quote, name, right_quote = ( - token.src[0], - token.src[1:-1], - token.src[-1], - ) - token = token._replace( - src=left_quote + self._unasync_name(name) + right_quote - ) + srcToken = token.src.replace("'", "") + if _STRING_REPLACE.get(srcToken) != None: + resulting_token = f"'{_STRING_REPLACE[srcToken]}'" + token = token._replace(src=resulting_token) new_tokens.append(token) token_counter = token_counter + 1 @@ -179,41 +182,39 @@ def unasync_files(fpath_list, rules): _IMPORTS_REPLACE["ably.http"] = "ably.sync.http" _IMPORTS_REPLACE["ably.rest"] = "ably.sync.rest" -# _IMPORTS_REPLACE["ably.types"] = "ably.types.sync" +_IMPORTS_REPLACE["test.ably.testapp"] = "test.ably.sync.testapp" +_IMPORTS_REPLACE["test.ably.utils"] = "test.ably.sync.utils" Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) -src_dir_path = os.path.join(os.getcwd(), "ably", "rest") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "rest") +src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") +dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) os.makedirs(dest_dir_path, exist_ok=True) def find_files(dir_path, file_name_regex) -> list[str]: - return glob.glob(os.path.join(dir_path, file_name_regex)) + return glob.glob(os.path.join(dir_path, file_name_regex), recursive=True) src_files = find_files(src_dir_path, "*.py") - unasync_files(src_files, (_DEFAULT_RULE,)) # round 2 -src_dir_path = os.path.join(os.getcwd(), "ably", "http") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "http") +src_dir_path = os.path.join(os.getcwd(), "test", "ably") +dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) -src_files = find_files(src_dir_path, "*.py") +src_files = find_files(src_dir_path, "testapp.py") unasync_files(src_files, (_DEFAULT_RULE,)) -# round 3 - -src_dir_path = os.path.join(os.getcwd(), "ably", "types") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "types") +src_dir_path = os.path.join(os.getcwd(), "test", "ably") +dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) -src_files = find_files(src_dir_path, "presence.py") +src_files = find_files(src_dir_path, "utils.py") unasync_files(src_files, (_DEFAULT_RULE,)) From d952ab006c892d12d0422b327536704c4d59bba2 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 14:50:42 +0530 Subject: [PATCH 09/52] Refactored unasync file, removed unnecessary build module code --- unasync.py | 77 +++++------------------------------------------------- 1 file changed, 7 insertions(+), 70 deletions(-) diff --git a/unasync.py b/unasync.py index 1b6fa3ef..454963f3 100644 --- a/unasync.py +++ b/unasync.py @@ -177,85 +177,22 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) -_IMPORTS_REPLACE["ably.http"] = "ably.sync.http" -_IMPORTS_REPLACE["ably.rest"] = "ably.sync.rest" -# _IMPORTS_REPLACE["ably.types"] = "ably.types.sync" +_IMPORTS_REPLACE["ably"] = "ably.sync" Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) -src_dir_path = os.path.join(os.getcwd(), "ably", "rest") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "rest") +src_dir_path = os.path.join(os.getcwd(), "ably") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) os.makedirs(dest_dir_path, exist_ok=True) def find_files(dir_path, file_name_regex) -> list[str]: - return glob.glob(os.path.join(dir_path, file_name_regex)) + return glob.glob(os.path.join(dir_path, "**", file_name_regex), recursive=True) -src_files = find_files(src_dir_path, "*.py") +relevant_src_files = (set(find_files(src_dir_path, "*.py")) - + set(find_files(dest_dir_path, "*.py"))) -unasync_files(src_files, (_DEFAULT_RULE,)) - -# round 2 -src_dir_path = os.path.join(os.getcwd(), "ably", "http") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "http") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - -src_files = find_files(src_dir_path, "*.py") - -unasync_files(src_files, (_DEFAULT_RULE,)) - -# round 3 - -src_dir_path = os.path.join(os.getcwd(), "ably", "types") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync", "types") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - -src_files = find_files(src_dir_path, "presence.py") - -unasync_files(src_files, (_DEFAULT_RULE,)) - -# class _build_py(orig.build_py): -# """ -# Subclass build_py from setuptools to modify its behavior. -# -# Convert files in _async dir from being asynchronous to synchronous -# and saves them in _sync dir. -# """ -# -# UNASYNC_RULES = (_DEFAULT_RULE,) -# -# def run(self): -# rules = self.UNASYNC_RULES -# -# self._updated_files = [] -# -# # Base class code -# if self.py_modules: -# self.build_modules() -# if self.packages: -# self.build_packages() -# self.build_package_data() -# -# # Our modification! -# unasync_files(self._updated_files, rules) -# -# # Remaining base class code -# self.byte_compile(self.get_outputs(include_bytecode=0)) -# -# def build_module(self, module, module_file, package): -# outfile, copied = super().build_module(module, module_file, package) -# if copied: -# self._updated_files.append(outfile) -# return outfile, copied -# -# -# def cmdclass_build_py(rules=(_DEFAULT_RULE,)): -# """Creates a 'build_py' class for use within 'cmdclass={"build_py": ...}'""" -# -# class _custom_build_py(_build_py): -# UNASYNC_RULES = rules -# -# return _custom_build_py +unasync_files(list(relevant_src_files), (_DEFAULT_RULE,)) From d83597fbcd4b2870ee6622eee1b3c569a1896a75 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 14:51:01 +0530 Subject: [PATCH 10/52] Refactored unasync_test, removed unnecessary module code --- unasync.py | 3 ++- unasync_test.py | 71 +++++++++---------------------------------------- 2 files changed, 14 insertions(+), 60 deletions(-) diff --git a/unasync.py b/unasync.py index 454963f3..73a70651 100644 --- a/unasync.py +++ b/unasync.py @@ -78,7 +78,8 @@ def _unasync_tokens(self, tokens: list): token = tokens[token_counter] if token.src in ["async", "await"]: - token_counter = token_counter + 2 # When removing async or await, we want to skip the following whitespace + # When removing async or await, we want to skip the following whitespace + token_counter = token_counter + 2 continue elif token.name == "NAME": if token.src == "from": diff --git a/unasync_test.py b/unasync_test.py index 96b7c721..0743ab07 100644 --- a/unasync_test.py +++ b/unasync_test.py @@ -23,7 +23,8 @@ "AsyncClient": "Client", "aclose": "close", "asyncSetUp": "setUp", - "asyncTearDown": "tearDown" + "asyncTearDown": "tearDown", + "AsyncMock": "Mock" } _IMPORTS_REPLACE = { @@ -31,7 +32,6 @@ } _STRING_REPLACE = { - '/../assets/testAppSpec.json': '/../../assets/testAppSpec.json' } @@ -85,7 +85,8 @@ def _unasync_tokens(self, tokens: list): if token.src == "'/../assets/testAppSpec.json'": print("hi") if token.src in ["async", "await"]: - token_counter = token_counter + 2 # When removing async or await, we want to skip the following whitespace + # When removing async or await, we want to skip the following whitespace + token_counter = token_counter + 2 continue elif token.name == "NAME": if token.src == "from": @@ -180,10 +181,12 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) -_IMPORTS_REPLACE["ably.http"] = "ably.sync.http" -_IMPORTS_REPLACE["ably.rest"] = "ably.sync.rest" -_IMPORTS_REPLACE["test.ably.testapp"] = "test.ably.sync.testapp" -_IMPORTS_REPLACE["test.ably.utils"] = "test.ably.sync.utils" +_IMPORTS_REPLACE["ably"] = "ably.sync" +_IMPORTS_REPLACE["test.ably"] = "test.ably.sync" + +_STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' +_STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.Auth.request_token' +_STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) @@ -206,57 +209,7 @@ def find_files(dir_path, file_name_regex) -> list[str]: dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) -src_files = find_files(src_dir_path, "testapp.py") +src_files = [os.path.join(os.getcwd(), "test", "ably", "testapp.py"), + os.path.join(os.getcwd(), "test", "ably", "utils.py")] unasync_files(src_files, (_DEFAULT_RULE,)) - -src_dir_path = os.path.join(os.getcwd(), "test", "ably") -dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - -src_files = find_files(src_dir_path, "utils.py") - -unasync_files(src_files, (_DEFAULT_RULE,)) - -# class _build_py(orig.build_py): -# """ -# Subclass build_py from setuptools to modify its behavior. -# -# Convert files in _async dir from being asynchronous to synchronous -# and saves them in _sync dir. -# """ -# -# UNASYNC_RULES = (_DEFAULT_RULE,) -# -# def run(self): -# rules = self.UNASYNC_RULES -# -# self._updated_files = [] -# -# # Base class code -# if self.py_modules: -# self.build_modules() -# if self.packages: -# self.build_packages() -# self.build_package_data() -# -# # Our modification! -# unasync_files(self._updated_files, rules) -# -# # Remaining base class code -# self.byte_compile(self.get_outputs(include_bytecode=0)) -# -# def build_module(self, module, module_file, package): -# outfile, copied = super().build_module(module, module_file, package) -# if copied: -# self._updated_files.append(outfile) -# return outfile, copied -# -# -# def cmdclass_build_py(rules=(_DEFAULT_RULE,)): -# """Creates a 'build_py' class for use within 'cmdclass={"build_py": ...}'""" -# -# class _custom_build_py(_build_py): -# UNASYNC_RULES = rules -# -# return _custom_build_py From e80c1e6fc48c65681636ccdde88da4fc609a927e Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 14:56:33 +0530 Subject: [PATCH 11/52] Fixed flake8 issues for unasync_test file --- unasync_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/unasync_test.py b/unasync_test.py index 0743ab07..692e86cb 100644 --- a/unasync_test.py +++ b/unasync_test.py @@ -82,8 +82,6 @@ def _unasync_tokens(self, tokens: list): token_counter = 0 while token_counter < len(tokens): token = tokens[token_counter] - if token.src == "'/../assets/testAppSpec.json'": - print("hi") if token.src in ["async", "await"]: # When removing async or await, we want to skip the following whitespace token_counter = token_counter + 2 @@ -96,10 +94,10 @@ def _unasync_tokens(self, tokens: list): else: token = token._replace(src=self._unasync_name(token.src)) elif token.name == "STRING": - srcToken = token.src.replace("'", "") - if _STRING_REPLACE.get(srcToken) != None: - resulting_token = f"'{_STRING_REPLACE[srcToken]}'" - token = token._replace(src=resulting_token) + src_token = token.src.replace("'", "") + if _STRING_REPLACE.get(src_token) is not None: + new_token = f"'{_STRING_REPLACE[src_token]}'" + token = token._replace(src=new_token) new_tokens.append(token) token_counter = token_counter + 1 From 67434a5249e1c66886bf2e0c4dfb760993d283fc Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 14:57:56 +0530 Subject: [PATCH 12/52] Added IDE specific files to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 71554b60..0d07b9f2 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,5 @@ app_spec app_spec.pkl ably/types/options.py.orig test/ably/restsetup.py.orig + +.idea/**/* \ No newline at end of file From f2f89cc4db324156552ae0b14daf23ae4a999113 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 15:03:04 +0530 Subject: [PATCH 13/52] Created sync directory to maintain generated sync code --- ably/sync/__init__.py | 18 + ably/sync/http/__init__.py | 0 ably/sync/http/http.py | 301 ++++++++++++ ably/sync/http/httputils.py | 55 +++ ably/sync/http/paginatedresult.py | 134 ++++++ ably/sync/realtime/__init__.py | 0 ably/sync/realtime/connection.py | 119 +++++ ably/sync/realtime/connectionmanager.py | 524 ++++++++++++++++++++ ably/sync/realtime/realtime.py | 140 ++++++ ably/sync/realtime/realtime_channel.py | 553 ++++++++++++++++++++++ ably/sync/rest/__init__.py | 0 ably/sync/rest/auth.py | 425 +++++++++++++++++ ably/sync/rest/channel.py | 229 +++++++++ ably/sync/rest/push.py | 189 ++++++++ ably/sync/rest/rest.py | 148 ++++++ ably/sync/transport/__init__.py | 0 ably/sync/transport/defaults.py | 63 +++ ably/sync/transport/websockettransport.py | 219 +++++++++ ably/sync/types/__init__.py | 0 ably/sync/types/authoptions.py | 157 ++++++ ably/sync/types/capability.py | 82 ++++ ably/sync/types/channeldetails.py | 116 +++++ ably/sync/types/channelstate.py | 22 + ably/sync/types/channelsubscription.py | 70 +++ ably/sync/types/connectiondetails.py | 20 + ably/sync/types/connectionerrors.py | 30 ++ ably/sync/types/connectionstate.py | 36 ++ ably/sync/types/device.py | 116 +++++ ably/sync/types/flags.py | 19 + ably/sync/types/message.py | 233 +++++++++ ably/sync/types/mixins.py | 75 +++ ably/sync/types/options.py | 330 +++++++++++++ ably/sync/types/presence.py | 174 +++++++ ably/sync/types/stats.py | 67 +++ ably/sync/types/tokendetails.py | 97 ++++ ably/sync/types/tokenrequest.py | 107 +++++ ably/sync/types/typedbuffer.py | 104 ++++ ably/sync/util/__init__.py | 0 ably/sync/util/case.py | 18 + ably/sync/util/crypto.py | 179 +++++++ ably/sync/util/eventemitter.py | 185 ++++++++ ably/sync/util/exceptions.py | 92 ++++ ably/sync/util/helper.py | 42 ++ ably/sync/util/nocrypto.py | 9 + 44 files changed, 5497 insertions(+) create mode 100644 ably/sync/__init__.py create mode 100644 ably/sync/http/__init__.py create mode 100644 ably/sync/http/http.py create mode 100644 ably/sync/http/httputils.py create mode 100644 ably/sync/http/paginatedresult.py create mode 100644 ably/sync/realtime/__init__.py create mode 100644 ably/sync/realtime/connection.py create mode 100644 ably/sync/realtime/connectionmanager.py create mode 100644 ably/sync/realtime/realtime.py create mode 100644 ably/sync/realtime/realtime_channel.py create mode 100644 ably/sync/rest/__init__.py create mode 100644 ably/sync/rest/auth.py create mode 100644 ably/sync/rest/channel.py create mode 100644 ably/sync/rest/push.py create mode 100644 ably/sync/rest/rest.py create mode 100644 ably/sync/transport/__init__.py create mode 100644 ably/sync/transport/defaults.py create mode 100644 ably/sync/transport/websockettransport.py create mode 100644 ably/sync/types/__init__.py create mode 100644 ably/sync/types/authoptions.py create mode 100644 ably/sync/types/capability.py create mode 100644 ably/sync/types/channeldetails.py create mode 100644 ably/sync/types/channelstate.py create mode 100644 ably/sync/types/channelsubscription.py create mode 100644 ably/sync/types/connectiondetails.py create mode 100644 ably/sync/types/connectionerrors.py create mode 100644 ably/sync/types/connectionstate.py create mode 100644 ably/sync/types/device.py create mode 100644 ably/sync/types/flags.py create mode 100644 ably/sync/types/message.py create mode 100644 ably/sync/types/mixins.py create mode 100644 ably/sync/types/options.py create mode 100644 ably/sync/types/presence.py create mode 100644 ably/sync/types/stats.py create mode 100644 ably/sync/types/tokendetails.py create mode 100644 ably/sync/types/tokenrequest.py create mode 100644 ably/sync/types/typedbuffer.py create mode 100644 ably/sync/util/__init__.py create mode 100644 ably/sync/util/case.py create mode 100644 ably/sync/util/crypto.py create mode 100644 ably/sync/util/eventemitter.py create mode 100644 ably/sync/util/exceptions.py create mode 100644 ably/sync/util/helper.py create mode 100644 ably/sync/util/nocrypto.py diff --git a/ably/sync/__init__.py b/ably/sync/__init__.py new file mode 100644 index 00000000..296dbf0d --- /dev/null +++ b/ably/sync/__init__.py @@ -0,0 +1,18 @@ +from ably.sync.rest.rest import AblyRest +from ably.sync.realtime.realtime import AblyRealtime +from ably.sync.rest.auth import Auth +from ably.sync.rest.push import Push +from ably.sync.types.capability import Capability +from ably.sync.types.channelsubscription import PushChannelSubscription +from ably.sync.types.device import DeviceDetails +from ably.sync.types.options import Options +from ably.sync.util.crypto import CipherParams +from ably.sync.util.exceptions import AblyException, AblyAuthException, IncompatibleClientIdException + +import logging + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +api_version = '3' +lib_version = '2.0.2' diff --git a/ably/sync/http/__init__.py b/ably/sync/http/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ably/sync/http/http.py b/ably/sync/http/http.py new file mode 100644 index 00000000..8e52da55 --- /dev/null +++ b/ably/sync/http/http.py @@ -0,0 +1,301 @@ +import functools +import logging +import time +import json +from urllib.parse import urljoin + +import httpx +import msgpack + +from ably.sync.rest.auth import Auth +from ably.sync.http.httputils import HttpUtils +from ably.sync.transport.defaults import Defaults +from ably.sync.util.exceptions import AblyException +from ably.sync.util.helper import is_token_error + +log = logging.getLogger(__name__) + + +def reauth_if_expired(func): + @functools.wraps(func) + def wrapper(rest, *args, **kwargs): + if kwargs.get("skip_auth"): + return func(rest, *args, **kwargs) + + # RSA4b1 Detect expired token to avoid round-trip request + auth = rest.auth + token_details = auth.token_details + if token_details and auth.time_offset is not None and auth.token_details_has_expired(): + auth.authorize() + retried = True + else: + retried = False + + try: + return func(rest, *args, **kwargs) + except AblyException as e: + if is_token_error(e) and not retried: + auth.authorize() + return func(rest, *args, **kwargs) + + raise e + + return wrapper + + +class Request: + def __init__(self, method='GET', url='/', version=None, headers=None, body=None, + skip_auth=False, raise_on_error=True): + self.__method = method + self.__headers = headers or {} + self.__body = body + self.__skip_auth = skip_auth + self.__url = url + self.__version = version + self.raise_on_error = raise_on_error + + def with_relative_url(self, relative_url): + url = urljoin(self.url, relative_url) + return Request(self.method, url, self.version, self.headers, self.body, + self.skip_auth, self.raise_on_error) + + @property + def method(self): + return self.__method + + @property + def url(self): + return self.__url + + @property + def headers(self): + return self.__headers + + @property + def body(self): + return self.__body + + @property + def skip_auth(self): + return self.__skip_auth + + @property + def version(self): + return self.__version + + +class Response: + """ + Composition for httpx.Response with delegation + """ + + def __init__(self, response): + self.__response = response + + def to_native(self): + content = self.__response.content + if not content: + return None + + content_type = self.__response.headers.get('content-type') + if isinstance(content_type, str): + if content_type.startswith('application/x-msgpack'): + return msgpack.unpackb(content) + elif content_type.startswith('application/json'): + return self.__response.json() + + raise ValueError("Unsupported content type") + + @property + def response(self): + return self.__response + + def __getattr__(self, attr): + return getattr(self.__response, attr) + + +class Http: + CONNECTION_RETRY_DEFAULTS = { + 'http_open_timeout': 4, + 'http_request_timeout': 10, + 'http_max_retry_duration': 15, + } + + def __init__(self, ably, options): + options = options or {} + self.__ably = ably + self.__options = options + self.__auth = None + # Cached fallback host (RSC15f) + self.__host = None + self.__host_expires = None + self.__client = httpx.Client(http2=True) + + def close(self): + self.__client.close() + + def dump_body(self, body): + if self.options.use_binary_protocol: + return msgpack.packb(body, use_bin_type=False) + else: + return json.dumps(body, separators=(',', ':')) + + def get_rest_hosts(self): + hosts = self.options.get_rest_hosts() + host = self.__host or self.options.fallback_realtime_host + if host is None: + return hosts + + if time.time() > self.__host_expires: + self.__host = None + self.__host_expires = None + return hosts + + hosts = list(hosts) + hosts.remove(host) + hosts.insert(0, host) + return hosts + + @reauth_if_expired + def make_request(self, method, path, version=None, headers=None, body=None, + skip_auth=False, timeout=None, raise_on_error=True): + + if body is not None and type(body) not in (bytes, str): + body = self.dump_body(body) + + if body: + all_headers = HttpUtils.default_post_headers(self.options.use_binary_protocol, version=version) + else: + all_headers = HttpUtils.default_get_headers(self.options.use_binary_protocol, version=version) + + params = HttpUtils.get_query_params(self.options) + + if not skip_auth: + if self.auth.auth_mechanism == Auth.Method.BASIC and self.preferred_scheme.lower() == 'http': + raise AblyException( + "Cannot use Basic Auth over non-TLS connections", + 401, + 40103) + auth_headers = self.auth._get_auth_headers() + all_headers.update(auth_headers) + if headers: + all_headers.update(headers) + + timeout = (self.http_open_timeout, self.http_request_timeout) + http_max_retry_duration = self.http_max_retry_duration + requested_at = time.time() + + hosts = self.get_rest_hosts() + for retry_count, host in enumerate(hosts): + base_url = "%s://%s:%d" % (self.preferred_scheme, + host, + self.preferred_port) + url = urljoin(base_url, path) + + request = self.__client.build_request( + method=method, + url=url, + content=body, + params=params, + headers=all_headers, + timeout=timeout, + ) + try: + response = self.__client.send(request) + except Exception as e: + # if last try or cumulative timeout is done, throw exception up + time_passed = time.time() - requested_at + if retry_count == len(hosts) - 1 or time_passed > http_max_retry_duration: + raise e + else: + try: + if raise_on_error: + AblyException.raise_for_response(response) + + # Keep fallback host for later (RSC15f) + if retry_count > 0 and host != self.options.get_rest_host(): + self.__host = host + self.__host_expires = time.time() + (self.options.fallback_retry_timeout / 1000.0) + + return Response(response) + except AblyException as e: + if not e.is_server_error: + raise e + + # if last try or cumulative timeout is done, throw exception up + time_passed = time.time() - requested_at + if retry_count == len(hosts) - 1 or time_passed > http_max_retry_duration: + raise e + + def delete(self, url, headers=None, skip_auth=False, timeout=None): + result = self.make_request('DELETE', url, headers=headers, + skip_auth=skip_auth, timeout=timeout) + return result + + def get(self, url, headers=None, skip_auth=False, timeout=None): + result = self.make_request('GET', url, headers=headers, + skip_auth=skip_auth, timeout=timeout) + return result + + def patch(self, url, headers=None, body=None, skip_auth=False, timeout=None): + result = self.make_request('PATCH', url, headers=headers, body=body, + skip_auth=skip_auth, timeout=timeout) + return result + + def post(self, url, headers=None, body=None, skip_auth=False, timeout=None): + result = self.make_request('POST', url, headers=headers, body=body, + skip_auth=skip_auth, timeout=timeout) + return result + + def put(self, url, headers=None, body=None, skip_auth=False, timeout=None): + result = self.make_request('PUT', url, headers=headers, body=body, + skip_auth=skip_auth, timeout=timeout) + return result + + @property + def auth(self): + return self.__auth + + @auth.setter + def auth(self, value): + self.__auth = value + + @property + def options(self): + return self.__options + + @property + def preferred_host(self): + return self.options.get_rest_host() + + @property + def preferred_port(self): + return Defaults.get_port(self.options) + + @property + def preferred_scheme(self): + return Defaults.get_scheme(self.options) + + @property + def http_open_timeout(self): + if self.options.http_open_timeout is not None: + return self.options.http_open_timeout + return self.CONNECTION_RETRY_DEFAULTS['http_open_timeout'] + + @property + def http_request_timeout(self): + if self.options.http_request_timeout is not None: + return self.options.http_request_timeout + return self.CONNECTION_RETRY_DEFAULTS['http_request_timeout'] + + @property + def http_max_retry_count(self): + if self.options.http_max_retry_count is not None: + return self.options.http_max_retry_count + return self.CONNECTION_RETRY_DEFAULTS['http_max_retry_count'] + + @property + def http_max_retry_duration(self): + if self.options.http_max_retry_duration is not None: + return self.options.http_max_retry_duration + return self.CONNECTION_RETRY_DEFAULTS['http_max_retry_duration'] diff --git a/ably/sync/http/httputils.py b/ably/sync/http/httputils.py new file mode 100644 index 00000000..b55ae75c --- /dev/null +++ b/ably/sync/http/httputils.py @@ -0,0 +1,55 @@ +import base64 +import os +import platform + +import ably + + +class HttpUtils: + default_format = "json" + + mime_types = { + "json": "application/json", + "xml": "application/xml", + "html": "text/html", + "binary": "application/x-msgpack", + } + + @staticmethod + def default_get_headers(binary=False, version=None): + headers = HttpUtils.default_headers(version=version) + if binary: + headers["Accept"] = HttpUtils.mime_types['binary'] + else: + headers["Accept"] = HttpUtils.mime_types['json'] + return headers + + @staticmethod + def default_post_headers(binary=False, version=None): + headers = HttpUtils.default_get_headers(binary=binary, version=version) + headers["Content-Type"] = headers["Accept"] + return headers + + @staticmethod + def get_host_header(host): + return { + 'Host': host, + } + + @staticmethod + def default_headers(version=None): + if version is None: + version = ably.api_version + return { + "X-Ably-Version": version, + "Ably-Agent": 'ably-python/%s python/%s' % (ably.lib_version, platform.python_version()) + } + + @staticmethod + def get_query_params(options): + params = {} + + if options.add_request_ids: + params['request_id'] = base64.urlsafe_b64encode(os.urandom(12)).decode('ascii') + + return params diff --git a/ably/sync/http/paginatedresult.py b/ably/sync/http/paginatedresult.py new file mode 100644 index 00000000..8dbc78ec --- /dev/null +++ b/ably/sync/http/paginatedresult.py @@ -0,0 +1,134 @@ +import calendar +import logging +from urllib.parse import urlencode + +from ably.sync.http.http import Request +from ably.sync.util import case + +log = logging.getLogger(__name__) + + +def format_time_param(t): + try: + return '%d' % (calendar.timegm(t.utctimetuple()) * 1000) + except Exception: + return str(t) + + +def format_params(params=None, direction=None, start=None, end=None, limit=None, **kw): + if params is None: + params = {} + + for key, value in kw.items(): + if value is not None: + key = case.snake_to_camel(key) + params[key] = value + + if direction: + params['direction'] = str(direction) + if start: + params['start'] = format_time_param(start) + if end: + params['end'] = format_time_param(end) + if limit: + if limit > 1000: + raise ValueError("The maximum allowed limit is 1000") + params['limit'] = '%d' % limit + + if 'start' in params and 'end' in params and params['start'] > params['end']: + raise ValueError("'end' parameter has to be greater than or equal to 'start'") + + return '?' + urlencode(params) if params else '' + + +class PaginatedResult: + def __init__(self, http, items, content_type, rel_first, rel_next, + response_processor, response): + self.__http = http + self.__items = items + self.__content_type = content_type + self.__rel_first = rel_first + self.__rel_next = rel_next + self.__response_processor = response_processor + self.response = response + + @property + def items(self): + return self.__items + + def has_first(self): + return self.__rel_first is not None + + def has_next(self): + return self.__rel_next is not None + + def is_last(self): + return not self.has_next() + + def first(self): + return self.__get_rel(self.__rel_first) if self.__rel_first else None + + def next(self): + return self.__get_rel(self.__rel_next) if self.__rel_next else None + + def __get_rel(self, rel_req): + if rel_req is None: + return None + return self.paginated_query_with_request(self.__http, rel_req, self.__response_processor) + + @classmethod + def paginated_query(cls, http, method='GET', url='/', version=None, body=None, + headers=None, response_processor=None, + raise_on_error=True): + headers = headers or {} + req = Request(method, url, version=version, body=body, headers=headers, skip_auth=False, + raise_on_error=raise_on_error) + return cls.paginated_query_with_request(http, req, response_processor) + + @classmethod + def paginated_query_with_request(cls, http, request, response_processor, + raise_on_error=True): + response = http.make_request( + request.method, request.url, version=request.version, + headers=request.headers, body=request.body, + skip_auth=request.skip_auth, raise_on_error=request.raise_on_error) + + items = response_processor(response) + + content_type = response.headers['Content-Type'] + links = response.links + if 'first' in links: + first_rel_request = request.with_relative_url(links['first']['url']) + else: + first_rel_request = None + + if 'next' in links: + next_rel_request = request.with_relative_url(links['next']['url']) + else: + next_rel_request = None + + return cls(http, items, content_type, first_rel_request, + next_rel_request, response_processor, response) + + +class HttpPaginatedResponse(PaginatedResult): + @property + def status_code(self): + return self.response.status_code + + @property + def success(self): + status_code = self.status_code + return 200 <= status_code < 300 + + @property + def error_code(self): + return self.response.headers.get('X-Ably-Errorcode') + + @property + def error_message(self): + return self.response.headers.get('X-Ably-Errormessage') + + @property + def headers(self): + return list(self.response.headers.items()) diff --git a/ably/sync/realtime/__init__.py b/ably/sync/realtime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ably/sync/realtime/connection.py b/ably/sync/realtime/connection.py new file mode 100644 index 00000000..9cf046ff --- /dev/null +++ b/ably/sync/realtime/connection.py @@ -0,0 +1,119 @@ +from __future__ import annotations +import functools +import logging +from ably.sync.realtime.connectionmanager import ConnectionManager +from ably.sync.types.connectiondetails import ConnectionDetails +from ably.sync.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange +from ably.sync.util.eventemitter import EventEmitter +from ably.sync.util.exceptions import AblyException +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ably.sync.realtime.realtime import AblyRealtime + +log = logging.getLogger(__name__) + + +class Connection(EventEmitter): # RTN4 + """Ably Realtime Connection + + Enables the management of a connection to Ably + + Attributes + ---------- + state: str + Connection state + error_reason: ErrorInfo + An ErrorInfo object describing the last error which occurred on the channel, if any. + + + Methods + ------- + connect() + Establishes a realtime connection + close() + Closes a realtime connection + ping() + Pings a realtime connection + """ + + def __init__(self, realtime: AblyRealtime): + self.__realtime = realtime + self.__error_reason: Optional[AblyException] = None + self.__state = ConnectionState.CONNECTING if realtime.options.auto_connect else ConnectionState.INITIALIZED + self.__connection_manager = ConnectionManager(self.__realtime, self.state) + self.__connection_manager.on('connectionstate', self._on_state_update) # RTN4a + self.__connection_manager.on('update', self._on_connection_update) # RTN4h + super().__init__() + + # RTN11 + def connect(self) -> None: + """Establishes a realtime connection. + + Causes the connection to open, entering the connecting state + """ + self.__error_reason = None + self.connection_manager.request_state(ConnectionState.CONNECTING) + + def close(self) -> None: + """Causes the connection to close, entering the closing state. + + Once closed, the library will not attempt to re-establish the + connection without an explicit call to connect() + """ + self.connection_manager.request_state(ConnectionState.CLOSING) + self.once_async(ConnectionState.CLOSED) + + # RTN13 + def ping(self) -> float: + """Send a ping to the realtime connection + + When connected, sends a heartbeat ping to the Ably server and executes + the callback with any error and the response time in milliseconds when + a heartbeat ping request is echoed from the server. + + Raises + ------ + AblyException + If ping request cannot be sent due to invalid state + + Returns + ------- + float + The response time in milliseconds + """ + return self.__connection_manager.ping() + + def _on_state_update(self, state_change: ConnectionStateChange) -> None: + log.info(f'Connection state changing from {self.state} to {state_change.current}') + self.__state = state_change.current + if state_change.reason is not None: + self.__error_reason = state_change.reason + self.__realtime.options.loop.call_soon(functools.partial(self._emit, state_change.current, state_change)) + + def _on_connection_update(self, state_change: ConnectionStateChange) -> None: + self.__realtime.options.loop.call_soon(functools.partial(self._emit, ConnectionEvent.UPDATE, state_change)) + + # RTN4d + @property + def state(self) -> ConnectionState: + """The current connection state of the connection""" + return self.__state + + # RTN25 + @property + def error_reason(self) -> Optional[AblyException]: + """An object describing the last error which occurred on the channel, if any.""" + return self.__error_reason + + @state.setter + def state(self, value: ConnectionState) -> None: + self.__state = value + + @property + def connection_manager(self) -> ConnectionManager: + return self.__connection_manager + + @property + def connection_details(self) -> Optional[ConnectionDetails]: + return self.__connection_manager.connection_details diff --git a/ably/sync/realtime/connectionmanager.py b/ably/sync/realtime/connectionmanager.py new file mode 100644 index 00000000..0be5a427 --- /dev/null +++ b/ably/sync/realtime/connectionmanager.py @@ -0,0 +1,524 @@ +from __future__ import annotations +import logging +import asyncio +import httpx +from ably.sync.transport.websockettransport import WebSocketTransport, ProtocolMessageAction +from ably.sync.transport.defaults import Defaults +from ably.sync.types.connectionerrors import ConnectionErrors +from ably.sync.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange +from ably.sync.types.tokendetails import TokenDetails +from ably.sync.util.exceptions import AblyException, IncompatibleClientIdException +from ably.sync.util.eventemitter import EventEmitter +from datetime import datetime +from ably.sync.util.helper import get_random_id, Timer, is_token_error +from typing import Optional, TYPE_CHECKING +from ably.sync.types.connectiondetails import ConnectionDetails +from queue import Queue + +if TYPE_CHECKING: + from ably.sync.realtime.realtime import AblyRealtime + +log = logging.getLogger(__name__) + + +class ConnectionManager(EventEmitter): + def __init__(self, realtime: AblyRealtime, initial_state): + self.options = realtime.options + self.__ably = realtime + self.__state: ConnectionState = initial_state + self.__ping_future: Optional[asyncio.Future] = None + self.__timeout_in_secs: float = self.options.realtime_request_timeout / 1000 + self.transport: Optional[WebSocketTransport] = None + self.__connection_details: Optional[ConnectionDetails] = None + self.connection_id: Optional[str] = None + self.__fail_state = ConnectionState.DISCONNECTED + self.transition_timer: Optional[Timer] = None + self.suspend_timer: Optional[Timer] = None + self.retry_timer: Optional[Timer] = None + self.connect_base_task: Optional[asyncio.Task] = None + self.disconnect_transport_task: Optional[asyncio.Task] = None + self.__fallback_hosts: list[str] = self.options.get_fallback_realtime_hosts() + self.queued_messages: Queue = Queue() + self.__error_reason: Optional[AblyException] = None + super().__init__() + + def enact_state_change(self, state: ConnectionState, reason: Optional[AblyException] = None) -> None: + current_state = self.__state + log.debug(f'ConnectionManager.enact_state_change(): {current_state} -> {state}; reason = {reason}') + self.__state = state + if reason: + self.__error_reason = reason + self._emit('connectionstate', ConnectionStateChange(current_state, state, state, reason)) + + def check_connection(self) -> bool: + try: + response = httpx.get(self.options.connectivity_check_url) + return 200 <= response.status_code < 300 and \ + (self.options.connectivity_check_url != Defaults.connectivity_check_url or "yes" in response.text) + except httpx.HTTPError: + return False + + def get_state_error(self) -> AblyException: + return ConnectionErrors[self.state] + + def __get_transport_params(self) -> dict: + protocol_version = Defaults.protocol_version + params = self.ably.auth.get_auth_transport_param() + params["v"] = protocol_version + if self.connection_details: + params["resume"] = self.connection_details.connection_key + return params + + def close_impl(self) -> None: + log.debug('ConnectionManager.close_impl()') + + self.cancel_suspend_timer() + self.start_transition_timer(ConnectionState.CLOSING, fail_state=ConnectionState.CLOSED) + if self.transport: + self.transport.dispose() + if self.connect_base_task: + self.connect_base_task.cancel() + if self.disconnect_transport_task: + self.disconnect_transport_task + self.cancel_retry_timer() + + self.notify_state(ConnectionState.CLOSED) + + def send_protocol_message(self, protocol_message: dict) -> None: + if self.state in ( + ConnectionState.DISCONNECTED, + ConnectionState.CONNECTING, + ): + self.queued_messages.put(protocol_message) + return + + if self.state == ConnectionState.CONNECTED: + if self.transport: + self.transport.send(protocol_message) + else: + log.exception( + "ConnectionManager.send_protocol_message(): can not send message with no active transport" + ) + return + + raise AblyException(f"ConnectionManager.send_protocol_message(): called in {self.state}", 500, 50000) + + def send_queued_messages(self) -> None: + log.info(f'ConnectionManager.send_queued_messages(): sending {self.queued_messages.qsize()} message(s)') + while not self.queued_messages.empty(): + asyncio.create_task(self.send_protocol_message(self.queued_messages.get())) + + def fail_queued_messages(self, err) -> None: + log.info( + f"ConnectionManager.fail_queued_messages(): discarding {self.queued_messages.qsize()} messages;" + + f" reason = {err}" + ) + while not self.queued_messages.empty(): + msg = self.queued_messages.get() + log.exception(f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: {msg}") + + def ping(self) -> float: + if self.__ping_future: + try: + response = self.__ping_future + except asyncio.CancelledError: + raise AblyException("Ping request cancelled due to request timeout", 504, 50003) + return response + + self.__ping_future = asyncio.Future() + if self.__state in [ConnectionState.CONNECTED, ConnectionState.CONNECTING]: + self.__ping_id = get_random_id() + ping_start_time = datetime.now().timestamp() + self.send_protocol_message({"action": ProtocolMessageAction.HEARTBEAT, + "id": self.__ping_id}) + else: + raise AblyException("Cannot send ping request. Calling ping in invalid state", 40000, 400) + try: + asyncio.wait_for(self.__ping_future, self.__timeout_in_secs) + except asyncio.TimeoutError: + raise AblyException("Timeout waiting for ping response", 504, 50003) + + ping_end_time = datetime.now().timestamp() + response_time_ms = (ping_end_time - ping_start_time) * 1000 + return round(response_time_ms, 2) + + def on_connected(self, connection_details: ConnectionDetails, connection_id: str, + reason: Optional[AblyException] = None) -> None: + self.__fail_state = ConnectionState.DISCONNECTED + + self.__connection_details = connection_details + self.connection_id = connection_id + + if connection_details.client_id: + try: + self.ably.auth._configure_client_id(connection_details.client_id) + except IncompatibleClientIdException as e: + self.notify_state(ConnectionState.FAILED, reason=e) + return + + if self.__state == ConnectionState.CONNECTED: + state_change = ConnectionStateChange(ConnectionState.CONNECTED, ConnectionState.CONNECTED, + ConnectionEvent.UPDATE) + self._emit(ConnectionEvent.UPDATE, state_change) + else: + self.notify_state(ConnectionState.CONNECTED, reason=reason) + + self.ably.channels._on_connected() + + def on_disconnected(self, exception: AblyException) -> None: + # RTN15h + if self.transport: + self.transport.dispose() + if exception: + status_code = exception.status_code + if status_code >= 500 and status_code <= 504: # RTN17f1 + if len(self.__fallback_hosts) > 0: + try: + self.connect_with_fallback_hosts(self.__fallback_hosts) + except Exception as e: + self.notify_state(self.__fail_state, reason=e) + return + else: + log.info("No fallback host to try for disconnected protocol message") + elif is_token_error(exception): + self.on_token_error(exception) + else: + self.notify_state(ConnectionState.DISCONNECTED, exception) + else: + log.warn("DISCONNECTED message received without error") + + def on_token_error(self, exception: AblyException) -> None: + if self.__error_reason is None or not is_token_error(self.__error_reason): + self.__error_reason = exception + try: + self.ably.auth._ensure_valid_auth_credentials(force=True) + except Exception as e: + self.on_error_from_authorize(e) + return + self.notify_state(self.__fail_state, exception, retry_immediately=True) + return + self.notify_state(self.__fail_state, exception) + + def on_error(self, msg: dict, exception: AblyException) -> None: + if msg.get("channel") is not None: # RTN15i + self.on_channel_message(msg) + return + if self.transport: + self.transport.dispose() + if is_token_error(exception): # RTN14b + self.on_token_error(exception) + else: + self.enact_state_change(ConnectionState.FAILED, exception) + + def on_error_from_authorize(self, exception: AblyException) -> None: + log.info("ConnectionManager.on_error_from_authorize(): err = %s", exception) + # RSA4a + if exception.code == 40171: + self.notify_state(ConnectionState.FAILED, exception) + elif exception.status_code == 403: + msg = 'Client configured authentication provider returned 403; failing the connection' + log.error(f'ConnectionManager.on_error_from_authorize(): {msg}') + self.notify_state(ConnectionState.FAILED, AblyException(msg, 403, 80019)) + else: + msg = 'Client configured authentication provider request failed' + log.warning(f'ConnectionManager.on_error_from_authorize: {msg}') + self.notify_state(self.__fail_state, AblyException(msg, 401, 80019)) + + def on_closed(self) -> None: + if self.transport: + self.transport.dispose() + if self.connect_base_task: + self.connect_base_task.cancel() + + def on_channel_message(self, msg: dict) -> None: + self.__ably.channels._on_channel_message(msg) + + def on_heartbeat(self, id: Optional[str]) -> None: + if self.__ping_future: + # Resolve on heartbeat from ping request. + if self.__ping_id == id: + if not self.__ping_future.cancelled(): + self.__ping_future.set_result(None) + self.__ping_future = None + + def deactivate_transport(self, reason: Optional[AblyException] = None): + self.transport = None + self.notify_state(ConnectionState.DISCONNECTED, reason) + + def request_state(self, state: ConnectionState, force=False) -> None: + log.debug(f'ConnectionManager.request_state(): state = {state}') + + if not force and state == self.state: + return + + if state == ConnectionState.CONNECTING and self.__state == ConnectionState.CONNECTED: + return + + if state == ConnectionState.CLOSING and self.__state == ConnectionState.CLOSED: + return + + if state == ConnectionState.CONNECTING and self.__state in (ConnectionState.CLOSED, + ConnectionState.FAILED): + self.ably.channels._initialize_channels() + + if not force: + self.enact_state_change(state) + + if state == ConnectionState.CONNECTING: + self.start_connect() + + if state == ConnectionState.CLOSING: + asyncio.create_task(self.close_impl()) + + def start_connect(self) -> None: + self.start_suspend_timer() + self.start_transition_timer(ConnectionState.CONNECTING) + self.connect_base_task = asyncio.create_task(self.connect_base()) + + def connect_with_fallback_hosts(self, fallback_hosts: list) -> Optional[Exception]: + for host in fallback_hosts: + try: + if self.check_connection(): + self.try_host(host) + return + else: + message = "Unable to connect, network unreachable" + log.exception(message) + exception = AblyException(message, status_code=404, code=80003) + self.notify_state(self.__fail_state, exception) + return + except Exception as exc: + exception = exc + log.exception(f'Connection to {host} failed, reason={exception}') + log.exception("No more fallback hosts to try") + return exception + + def connect_base(self) -> None: + fallback_hosts = self.__fallback_hosts + primary_host = self.options.get_realtime_host() + try: + self.try_host(primary_host) + return + except Exception as exception: + log.exception(f'Connection to {primary_host} failed, reason={exception}') + if len(fallback_hosts) > 0: + log.info("Attempting connection to fallback host(s)") + resp = self.connect_with_fallback_hosts(fallback_hosts) + if not resp: + return + exception = resp + self.notify_state(self.__fail_state, reason=exception) + + def try_host(self, host) -> None: + try: + params = self.__get_transport_params() + except AblyException as e: + self.on_error_from_authorize(e) + return + self.transport = WebSocketTransport(self, host, params) + self._emit('transport.pending', self.transport) + self.transport.connect() + + future = asyncio.Future() + + def on_transport_connected(): + log.debug('ConnectionManager.try_a_host(): transport connected') + if self.transport: + self.transport.off('failed', on_transport_failed) + if not future.done(): + future.set_result(None) + + def on_transport_failed(exception): + log.info('ConnectionManager.try_a_host(): transport failed') + if self.transport: + self.transport.off('connected', on_transport_connected) + self.transport.dispose() + future.set_exception(exception) + + self.transport.once('connected', on_transport_connected) + self.transport.once('failed', on_transport_failed) + # Fix asyncio CancelledError in python 3.7 + try: + future + except asyncio.CancelledError: + return + + def notify_state(self, state: ConnectionState, reason: Optional[AblyException] = None, + retry_immediately: Optional[bool] = None) -> None: + # RTN15a + retry_immediately = (retry_immediately is not False) and ( + state == ConnectionState.DISCONNECTED and self.__state == ConnectionState.CONNECTED) + + log.debug( + f'ConnectionManager.notify_state(): new state: {state}' + + ('; will retry immediately' if retry_immediately else '') + ) + + if state == self.__state: + return + + self.cancel_transition_timer() + self.check_suspend_timer(state) + + if retry_immediately: + self.options.loop.call_soon(self.request_state, ConnectionState.CONNECTING) + elif state == ConnectionState.DISCONNECTED: + self.start_retry_timer(self.options.disconnected_retry_timeout) + elif state == ConnectionState.SUSPENDED: + self.start_retry_timer(self.options.suspended_retry_timeout) + + if (state == ConnectionState.DISCONNECTED and not retry_immediately) or state == ConnectionState.SUSPENDED: + self.disconnect_transport() + + self.enact_state_change(state, reason) + + if state == ConnectionState.CONNECTED: + self.send_queued_messages() + elif state in ( + ConnectionState.CLOSING, + ConnectionState.CLOSED, + ConnectionState.SUSPENDED, + ConnectionState.FAILED, + ): + self.fail_queued_messages(reason) + self.ably.channels._propagate_connection_interruption(state, reason) + + def start_transition_timer(self, state: ConnectionState, fail_state: Optional[ConnectionState] = None) -> None: + log.debug(f'ConnectionManager.start_transition_timer(): transition state = {state}') + + if self.transition_timer: + log.debug('ConnectionManager.start_transition_timer(): clearing already-running timer') + self.transition_timer.cancel() + + if fail_state is None: + fail_state = self.__fail_state if state != ConnectionState.CLOSING else ConnectionState.CLOSED + + timeout = self.options.realtime_request_timeout + + def on_transition_timer_expire(): + if self.transition_timer: + self.transition_timer = None + log.info(f'ConnectionManager {state} timer expired, notifying new state: {fail_state}') + self.notify_state( + fail_state, + AblyException("Connection cancelled due to request timeout", 504, 50003) + ) + + log.debug(f'ConnectionManager.start_transition_timer(): setting timer for {timeout}ms') + + self.transition_timer = Timer(timeout, on_transition_timer_expire) + + def cancel_transition_timer(self): + log.debug('ConnectionManager.cancel_transition_timer()') + if self.transition_timer: + self.transition_timer.cancel() + self.transition_timer = None + + def start_suspend_timer(self) -> None: + log.debug('ConnectionManager.start_suspend_timer()') + if self.suspend_timer: + return + + def on_suspend_timer_expire() -> None: + if self.suspend_timer: + self.suspend_timer = None + log.info('ConnectionManager suspend timer expired, requesting new state: suspended') + self.notify_state( + ConnectionState.SUSPENDED, + AblyException("Connection to server unavailable", 400, 80002) + ) + self.__fail_state = ConnectionState.SUSPENDED + self.__connection_details = None + + self.suspend_timer = Timer(Defaults.connection_state_ttl, on_suspend_timer_expire) + + def check_suspend_timer(self, state: ConnectionState) -> None: + if state not in ( + ConnectionState.CONNECTING, + ConnectionState.DISCONNECTED, + ConnectionState.SUSPENDED, + ): + self.cancel_suspend_timer() + + def cancel_suspend_timer(self) -> None: + log.debug('ConnectionManager.cancel_suspend_timer()') + self.__fail_state = ConnectionState.DISCONNECTED + if self.suspend_timer: + self.suspend_timer.cancel() + self.suspend_timer = None + + def start_retry_timer(self, interval: int) -> None: + def on_retry_timeout(): + log.info('ConnectionManager retry timer expired, retrying') + self.retry_timer = None + self.request_state(ConnectionState.CONNECTING) + + self.retry_timer = Timer(interval, on_retry_timeout) + + def cancel_retry_timer(self) -> None: + if self.retry_timer: + self.retry_timer.cancel() + self.retry_timer = None + + def disconnect_transport(self) -> None: + log.info('ConnectionManager.disconnect_transport()') + if self.transport: + self.disconnect_transport_task = asyncio.create_task(self.transport.dispose()) + + def on_auth_updated(self, token_details: TokenDetails): + log.info(f"ConnectionManager.on_auth_updated(): state = {self.state}") + if self.state == ConnectionState.CONNECTED: + auth_message = { + "action": ProtocolMessageAction.AUTH, + "auth": { + "accessToken": token_details.token + } + } + self.send_protocol_message(auth_message) + + state_change = self.once_async() + + if state_change.current == ConnectionState.CONNECTED: + return + elif state_change.current == ConnectionState.FAILED: + raise state_change.reason + elif self.state == ConnectionState.CONNECTING: + if self.connect_base_task and not self.connect_base_task.done(): + self.connect_base_task.cancel() + if self.transport: + self.transport.dispose() + if self.state != ConnectionState.CONNECTED: + future = asyncio.Future() + + def on_state_change(state_change: ConnectionStateChange) -> None: + if state_change.current == ConnectionState.CONNECTED: + self.off('connectionstate', on_state_change) + future.set_result(token_details) + if state_change.current in ( + ConnectionState.CLOSED, + ConnectionState.FAILED, + ConnectionState.SUSPENDED + ): + self.off('connectionstate', on_state_change) + future.set_exception(state_change.reason or self.get_state_error()) + + self.on('connectionstate', on_state_change) + + if self.state == ConnectionState.CONNECTING: + self.start_connect() + else: + self.request_state(ConnectionState.CONNECTING) + + return future + + @property + def ably(self): + return self.__ably + + @property + def state(self) -> ConnectionState: + return self.__state + + @property + def connection_details(self) -> Optional[ConnectionDetails]: + return self.__connection_details diff --git a/ably/sync/realtime/realtime.py b/ably/sync/realtime/realtime.py new file mode 100644 index 00000000..51028a08 --- /dev/null +++ b/ably/sync/realtime/realtime.py @@ -0,0 +1,140 @@ +import logging +import asyncio +from typing import Optional +from ably.sync.realtime.realtime_channel import Channels +from ably.sync.realtime.connection import Connection, ConnectionState +from ably.sync.rest.rest import AblyRest + + +log = logging.getLogger(__name__) + + +class AblyRealtime(AblyRest): + """ + Ably Realtime Client + + Attributes + ---------- + loop: AbstractEventLoop + asyncio running event loop + auth: Auth + authentication object + options: Options + auth options object + connection: Connection + realtime connection object + channels: Channels + realtime channel object + + Methods + ------- + connect() + Establishes the realtime connection + close() + Closes the realtime connection + """ + + def __init__(self, key: Optional[str] = None, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs): + """Constructs a RealtimeClient object using an Ably API key. + + Parameters + ---------- + key: str + A valid ably API key string + loop: AbstractEventLoop, optional + asyncio running event loop + auto_connect: bool + When true, the client connects to Ably as soon as it is instantiated. + You can set this to false and explicitly connect to Ably using the + connect() method. The default is true. + **kwargs: client options + realtime_host: str + Enables a non-default Ably host to be specified for realtime connections. + For development environments only. The default value is realtime.ably.io. + environment: str + Enables a custom environment to be used with the Ably service. Defaults to `production` + realtime_request_timeout: float + Timeout (in milliseconds) for the wait of acknowledgement for operations performed via a realtime + connection. Operations include establishing a connection with Ably, or sending a HEARTBEAT, + CONNECT, ATTACH, DETACH or CLOSE request. The default is 10 seconds(10000 milliseconds). + disconnected_retry_timeout: float + If the connection is still in the DISCONNECTED state after this delay, the client library will + attempt to reconnect automatically. The default is 15 seconds. + channel_retry_timeout: float + When a channel becomes SUSPENDED following a server initiated DETACHED, after this delay, if the + channel is still SUSPENDED and the connection is in CONNECTED, the client library will attempt to + re-attach the channel automatically. The default is 15 seconds. + fallback_hosts: list[str] + An array of fallback hosts to be used in the case of an error necessitating the use of an + alternative host. If you have been provided a set of custom fallback hosts by Ably, please specify + them here. + connection_state_ttl: float + The duration that Ably will persist the connection state for when a Realtime client is abruptly + disconnected. + suspended_retry_timeout: float + When the connection enters the SUSPENDED state, after this delay, if the state is still SUSPENDED, + the client library attempts to reconnect automatically. The default is 30 seconds. + connectivity_check_url: string + Override the URL used by the realtime client to check if the internet is available. + In the event of a failure to connect to the primary endpoint, the client will send a + GET request to this URL to check if the internet is available. If this request returns + a success response the client will attempt to connect to a fallback host. + Raises + ------ + ValueError + If no authentication key is not provided + """ + + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + log.warning('Realtime client created outside event loop') + + self._is_realtime: bool = True + + # RTC1 + super().__init__(key, loop=loop, **kwargs) + + self.key = key + self.__connection = Connection(self) + self.__channels = Channels(self) + + # RTN3 + if self.options.auto_connect: + self.connection.connection_manager.request_state(ConnectionState.CONNECTING, force=True) + + # RTC15 + def connect(self) -> None: + """Establishes a realtime connection. + + Explicitly calling connect() is unnecessary unless the autoConnect attribute of the ClientOptions object + is false. Unless already connected or connecting, this method causes the connection to open, entering the + CONNECTING state. + """ + log.info('Realtime.connect() called') + # RTC15a + self.connection.connect() + + # RTC16 + def close(self) -> None: + """Causes the connection to close, entering the closing state. + Once closed, the library will not attempt to re-establish the + connection without an explicit call to connect() + """ + log.info('Realtime.close() called') + # RTC16a + self.connection.close() + super().close() + + # RTC2 + @property + def connection(self) -> Connection: + """Returns the realtime connection object""" + return self.__connection + + # RTC3, RTS1 + @property + def channels(self) -> Channels: + """Returns the realtime channel object""" + return self.__channels diff --git a/ably/sync/realtime/realtime_channel.py b/ably/sync/realtime/realtime_channel.py new file mode 100644 index 00000000..5ed99393 --- /dev/null +++ b/ably/sync/realtime/realtime_channel.py @@ -0,0 +1,553 @@ +from __future__ import annotations +import asyncio +import logging +from typing import Optional, TYPE_CHECKING +from ably.sync.realtime.connection import ConnectionState +from ably.sync.transport.websockettransport import ProtocolMessageAction +from ably.sync.rest.channel import Channel, Channels as RestChannels +from ably.sync.types.channelstate import ChannelState, ChannelStateChange +from ably.sync.types.flags import Flag, has_flag +from ably.sync.types.message import Message +from ably.sync.util.eventemitter import EventEmitter +from ably.sync.util.exceptions import AblyException +from ably.sync.util.helper import Timer, is_callable_or_coroutine + +if TYPE_CHECKING: + from ably.sync.realtime.realtime import AblyRealtime + +log = logging.getLogger(__name__) + + +class RealtimeChannel(EventEmitter, Channel): + """ + Ably Realtime Channel + + Attributes + ---------- + name: str + Channel name + state: str + Channel state + error_reason: AblyException + An AblyException instance describing the last error which occurred on the channel, if any. + + Methods + ------- + attach() + Attach to channel + detach() + Detach from channel + subscribe(*args) + Subscribe to messages on a channel + unsubscribe(*args) + Unsubscribe to messages from a channel + """ + + def __init__(self, realtime: AblyRealtime, name: str): + EventEmitter.__init__(self) + self.__name = name + self.__realtime = realtime + self.__state = ChannelState.INITIALIZED + self.__message_emitter = EventEmitter() + self.__state_timer: Optional[Timer] = None + self.__attach_resume = False + self.__channel_serial: Optional[str] = None + self.__retry_timer: Optional[Timer] = None + self.__error_reason: Optional[AblyException] = None + + # Used to listen to state changes internally, if we use the public event emitter interface then internals + # will be disrupted if the user called .off() to remove all listeners + self.__internal_state_emitter = EventEmitter() + + Channel.__init__(self, realtime, name, {}) + + # RTL4 + def attach(self) -> None: + """Attach to channel + + Attach to this channel ensuring the channel is created in the Ably system and all messages published + on the channel are received by any channel listeners registered using subscribe + + Raises + ------ + AblyException + If unable to attach channel + """ + + log.info(f'RealtimeChannel.attach() called, channel = {self.name}') + + # RTL4a - if channel is attached do nothing + if self.state == ChannelState.ATTACHED: + return + + self.__error_reason = None + + # RTL4b + if self.__realtime.connection.state not in [ + ConnectionState.CONNECTING, + ConnectionState.CONNECTED, + ConnectionState.DISCONNECTED + ]: + raise AblyException( + message=f"Unable to attach; channel state = {self.state}", + code=90001, + status_code=400 + ) + + if self.state != ChannelState.ATTACHING: + self._request_state(ChannelState.ATTACHING) + + state_change = self.__internal_state_emitter.once_async() + + if state_change.current in (ChannelState.SUSPENDED, ChannelState.FAILED): + raise state_change.reason + + def _attach_impl(self): + log.debug("RealtimeChannel.attach_impl(): sending ATTACH protocol message") + + # RTL4c + attach_msg = { + "action": ProtocolMessageAction.ATTACH, + "channel": self.name, + } + + if self.__attach_resume: + attach_msg["flags"] = Flag.ATTACH_RESUME + if self.__channel_serial: + attach_msg["channelSerial"] = self.__channel_serial + + self._send_message(attach_msg) + + # RTL5 + def detach(self) -> None: + """Detach from channel + + Any resulting channel state change is emitted to any listeners registered + Once all clients globally have detached from the channel, the channel will be released + in the Ably service within two minutes. + + Raises + ------ + AblyException + If unable to detach channel + """ + + log.info(f'RealtimeChannel.detach() called, channel = {self.name}') + + # RTL5g, RTL5b - raise exception if state invalid + if self.__realtime.connection.state in [ConnectionState.CLOSING, ConnectionState.FAILED]: + raise AblyException( + message=f"Unable to detach; channel state = {self.state}", + code=90001, + status_code=400 + ) + + # RTL5a - if channel already detached do nothing + if self.state in [ChannelState.INITIALIZED, ChannelState.DETACHED]: + return + + if self.state == ChannelState.SUSPENDED: + self._notify_state(ChannelState.DETACHED) + return + elif self.state == ChannelState.FAILED: + raise AblyException("Unable to detach; channel state = failed", 90001, 400) + else: + self._request_state(ChannelState.DETACHING) + + # RTL5h - wait for pending connection + if self.__realtime.connection.state == ConnectionState.CONNECTING: + self.__realtime.connect() + + state_change = self.__internal_state_emitter.once_async() + new_state = state_change.current + + if new_state == ChannelState.DETACHED: + return + elif new_state == ChannelState.ATTACHING: + raise AblyException("Detach request superseded by a subsequent attach request", 90000, 409) + else: + raise state_change.reason + + def _detach_impl(self) -> None: + log.debug("RealtimeChannel.detach_impl(): sending DETACH protocol message") + + # RTL5d + detach_msg = { + "action": ProtocolMessageAction.DETACH, + "channel": self.__name, + } + + self._send_message(detach_msg) + + # RTL7 + def subscribe(self, *args) -> None: + """Subscribe to a channel + + Registers a listener for messages on the channel. + The caller supplies a listener function, which is called + each time one or more messages arrives on the channel. + + The function resolves once the channel is attached. + + Parameters + ---------- + *args: event, listener + Subscribe event and listener + + arg1(event): str, optional + Subscribe to messages with the given event name + + arg2(listener): callable + Subscribe to all messages on the channel + + When no event is provided, arg1 is used as the listener. + + Raises + ------ + AblyException + If unable to subscribe to a channel due to invalid connection state + ValueError + If no valid subscribe arguments are passed + """ + if isinstance(args[0], str): + event = args[0] + if not args[1]: + raise ValueError("channel.subscribe called without listener") + if not is_callable_or_coroutine(args[1]): + raise ValueError("subscribe listener must be function or coroutine function") + listener = args[1] + elif is_callable_or_coroutine(args[0]): + listener = args[0] + event = None + else: + raise ValueError('invalid subscribe arguments') + + log.info(f'RealtimeChannel.subscribe called, channel = {self.name}, event = {event}') + + if event is not None: + # RTL7b + self.__message_emitter.on(event, listener) + else: + # RTL7a + self.__message_emitter.on(listener) + + # RTL7c + self.attach() + + # RTL8 + def unsubscribe(self, *args) -> None: + """Unsubscribe from a channel + + Deregister the given listener for (for any/all event names). + This removes an earlier event-specific subscription. + + Parameters + ---------- + *args: event, listener + Unsubscribe event and listener + + arg1(event): str, optional + Unsubscribe to messages with the given event name + + arg2(listener): callable + Unsubscribe to all messages on the channel + + When no event is provided, arg1 is used as the listener. + + Raises + ------ + ValueError + If no valid unsubscribe arguments are passed, no listener or listener is not a function + or coroutine + """ + if len(args) == 0: + event = None + listener = None + elif isinstance(args[0], str): + event = args[0] + if not args[1]: + raise ValueError("channel.unsubscribe called without listener") + if not is_callable_or_coroutine(args[1]): + raise ValueError("unsubscribe listener must be a function or coroutine function") + listener = args[1] + elif is_callable_or_coroutine(args[0]): + listener = args[0] + event = None + else: + raise ValueError('invalid unsubscribe arguments') + + log.info(f'RealtimeChannel.unsubscribe called, channel = {self.name}, event = {event}') + + if listener is None: + # RTL8c + self.__message_emitter.off() + elif event is not None: + # RTL8b + self.__message_emitter.off(event, listener) + else: + # RTL8a + self.__message_emitter.off(listener) + + def _on_message(self, proto_msg: dict) -> None: + action = proto_msg.get('action') + # RTL4c1 + channel_serial = proto_msg.get('channelSerial') + if channel_serial: + self.__channel_serial = channel_serial + # TM2a, TM2c, TM2f + Message.update_inner_message_fields(proto_msg) + + if action == ProtocolMessageAction.ATTACHED: + flags = proto_msg.get('flags') + error = proto_msg.get("error") + exception = None + resumed = False + + if error: + exception = AblyException.from_dict(error) + + if flags: + resumed = has_flag(flags, Flag.RESUMED) + + # RTL12 + if self.state == ChannelState.ATTACHED: + if not resumed: + state_change = ChannelStateChange(self.state, ChannelState.ATTACHED, resumed, exception) + self._emit("update", state_change) + elif self.state == ChannelState.ATTACHING: + self._notify_state(ChannelState.ATTACHED, resumed=resumed) + else: + log.warn("RealtimeChannel._on_message(): ATTACHED received while not attaching") + elif action == ProtocolMessageAction.DETACHED: + if self.state == ChannelState.DETACHING: + self._notify_state(ChannelState.DETACHED) + elif self.state == ChannelState.ATTACHING: + self._notify_state(ChannelState.SUSPENDED) + else: + self._request_state(ChannelState.ATTACHING) + elif action == ProtocolMessageAction.MESSAGE: + messages = Message.from_encoded_array(proto_msg.get('messages')) + for message in messages: + self.__message_emitter._emit(message.name, message) + elif action == ProtocolMessageAction.ERROR: + error = AblyException.from_dict(proto_msg.get('error')) + self._notify_state(ChannelState.FAILED, reason=error) + + def _request_state(self, state: ChannelState) -> None: + log.debug(f'RealtimeChannel._request_state(): state = {state}') + self._notify_state(state) + self._check_pending_state() + + def _notify_state(self, state: ChannelState, reason: Optional[AblyException] = None, + resumed: bool = False) -> None: + log.debug(f'RealtimeChannel._notify_state(): state = {state}') + + self.__clear_state_timer() + + if state == self.state: + return + + if reason is not None: + self.__error_reason = reason + + if state == ChannelState.INITIALIZED: + self.__error_reason = None + + if state == ChannelState.SUSPENDED and self.ably.connection.state == ConnectionState.CONNECTED: + self.__start_retry_timer() + else: + self.__cancel_retry_timer() + + # RTL4j1 + if state == ChannelState.ATTACHED: + self.__attach_resume = True + if state in (ChannelState.DETACHING, ChannelState.FAILED): + self.__attach_resume = False + + # RTP5a1 + if state in (ChannelState.DETACHED, ChannelState.SUSPENDED, ChannelState.FAILED): + self.__channel_serial = None + + state_change = ChannelStateChange(self.__state, state, resumed, reason=reason) + + self.__state = state + self._emit(state, state_change) + self.__internal_state_emitter._emit(state, state_change) + + def _send_message(self, msg: dict) -> None: + asyncio.create_task(self.__realtime.connection.connection_manager.send_protocol_message(msg)) + + def _check_pending_state(self): + connection_state = self.__realtime.connection.connection_manager.state + + if connection_state is not ConnectionState.CONNECTED: + log.debug(f"RealtimeChannel._check_pending_state(): connection state = {connection_state}") + return + + if self.state == ChannelState.ATTACHING: + self.__start_state_timer() + self._attach_impl() + elif self.state == ChannelState.DETACHING: + self.__start_state_timer() + self._detach_impl() + + def __start_state_timer(self) -> None: + if not self.__state_timer: + def on_timeout() -> None: + log.debug('RealtimeChannel.start_state_timer(): timer expired') + self.__state_timer = None + self.__timeout_pending_state() + + self.__state_timer = Timer(self.__realtime.options.realtime_request_timeout, on_timeout) + + def __clear_state_timer(self) -> None: + if self.__state_timer: + self.__state_timer.cancel() + self.__state_timer = None + + def __timeout_pending_state(self) -> None: + if self.state == ChannelState.ATTACHING: + self._notify_state( + ChannelState.SUSPENDED, reason=AblyException("Channel attach timed out", 408, 90007)) + elif self.state == ChannelState.DETACHING: + self._notify_state(ChannelState.ATTACHED, reason=AblyException("Channel detach timed out", 408, 90007)) + else: + self._check_pending_state() + + def __start_retry_timer(self) -> None: + if self.__retry_timer: + return + + self.__retry_timer = Timer(self.ably.options.channel_retry_timeout, self.__on_retry_timer_expire) + + def __cancel_retry_timer(self) -> None: + if self.__retry_timer: + self.__retry_timer.cancel() + self.__retry_timer = None + + def __on_retry_timer_expire(self) -> None: + if self.state == ChannelState.SUSPENDED and self.ably.connection.state == ConnectionState.CONNECTED: + self.__retry_timer = None + log.info("RealtimeChannel retry timer expired, attempting a new attach") + self._request_state(ChannelState.ATTACHING) + + # RTL23 + @property + def name(self) -> str: + """Returns channel name""" + return self.__name + + # RTL2b + @property + def state(self) -> ChannelState: + """Returns channel state""" + return self.__state + + @state.setter + def state(self, state: ChannelState) -> None: + self.__state = state + + # RTL24 + @property + def error_reason(self) -> Optional[AblyException]: + """An AblyException instance describing the last error which occurred on the channel, if any.""" + return self.__error_reason + + +class Channels(RestChannels): + """Creates and destroys RealtimeChannel objects. + + Methods + ------- + get(name) + Gets a channel + release(name) + Releases a channel + """ + + # RTS3 + def get(self, name: str) -> RealtimeChannel: + """Creates a new RealtimeChannel object, or returns the existing channel object. + + Parameters + ---------- + + name: str + Channel name + """ + if name not in self.__all: + channel = self.__all[name] = RealtimeChannel(self.__ably, name) + else: + channel = self.__all[name] + return channel + + # RTS4 + def release(self, name: str) -> None: + """Releases a RealtimeChannel object, deleting it, and enabling it to be garbage collected + + It also removes any listeners associated with the channel. + To release a channel, the channel state must be INITIALIZED, DETACHED, or FAILED. + + + Parameters + ---------- + name: str + Channel name + """ + if name not in self.__all: + return + del self.__all[name] + + def _on_channel_message(self, msg: dict) -> None: + channel_name = msg.get('channel') + if not channel_name: + log.error( + 'Channels.on_channel_message()', + f'received event without channel, action = {msg.get("action")}' + ) + return + + channel = self.__all[channel_name] + if not channel: + log.warning( + 'Channels.on_channel_message()', + f'receieved event for non-existent channel: {channel_name}' + ) + return + + channel._on_message(msg) + + def _propagate_connection_interruption(self, state: ConnectionState, reason: Optional[AblyException]) -> None: + from_channel_states = ( + ChannelState.ATTACHING, + ChannelState.ATTACHED, + ChannelState.DETACHING, + ChannelState.SUSPENDED, + ) + + connection_to_channel_state = { + ConnectionState.CLOSING: ChannelState.DETACHED, + ConnectionState.CLOSED: ChannelState.DETACHED, + ConnectionState.FAILED: ChannelState.FAILED, + ConnectionState.SUSPENDED: ChannelState.SUSPENDED, + } + + for channel_name in self.__all: + channel = self.__all[channel_name] + if channel.state in from_channel_states: + channel._notify_state(connection_to_channel_state[state], reason) + + def _on_connected(self) -> None: + for channel_name in self.__all: + channel = self.__all[channel_name] + if channel.state == ChannelState.ATTACHING or channel.state == ChannelState.DETACHING: + channel._check_pending_state() + elif channel.state == ChannelState.SUSPENDED: + asyncio.create_task(channel.attach()) + elif channel.state == ChannelState.ATTACHED: + channel._request_state(ChannelState.ATTACHING) + + def _initialize_channels(self) -> None: + for channel_name in self.__all: + channel = self.__all[channel_name] + channel._request_state(ChannelState.INITIALIZED) diff --git a/ably/sync/rest/__init__.py b/ably/sync/rest/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ably/sync/rest/auth.py b/ably/sync/rest/auth.py new file mode 100644 index 00000000..a35e1fc2 --- /dev/null +++ b/ably/sync/rest/auth.py @@ -0,0 +1,425 @@ +from __future__ import annotations +import base64 +from datetime import timedelta +import logging +import time +from typing import Optional, TYPE_CHECKING, Union +import uuid +import httpx + +from ably.sync.types.options import Options +if TYPE_CHECKING: + from ably.sync.rest.rest import AblyRest + from ably.sync.realtime.realtime import AblyRealtime + +from ably.sync.types.capability import Capability +from ably.sync.types.tokendetails import TokenDetails +from ably.sync.types.tokenrequest import TokenRequest +from ably.sync.util.exceptions import AblyAuthException, AblyException, IncompatibleClientIdException + +__all__ = ["Auth"] + +log = logging.getLogger(__name__) + + +class Auth: + + class Method: + BASIC = "BASIC" + TOKEN = "TOKEN" + + def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options): + self.__ably = ably + self.__auth_options = options + + if not self.ably._is_realtime: + self.__client_id = options.client_id + if not self.__client_id and options.token_details: + self.__client_id = options.token_details.client_id + else: + self.__client_id = None + self.__client_id_validated: bool = False + + self.__basic_credentials: Optional[str] = None + self.__auth_params: Optional[dict] = None + self.__token_details: Optional[TokenDetails] = None + self.__time_offset: Optional[int] = None + + must_use_token_auth = options.use_token_auth is True + must_not_use_token_auth = options.use_token_auth is False + can_use_basic_auth = options.key_secret is not None + if not must_use_token_auth and can_use_basic_auth: + # We have the key, no need to authenticate the client + # default to using basic auth + log.debug("anonymous, using basic auth") + self.__auth_mechanism = Auth.Method.BASIC + basic_key = "%s:%s" % (options.key_name, options.key_secret) + basic_key = base64.b64encode(basic_key.encode('utf-8')) + self.__basic_credentials = basic_key.decode('ascii') + return + elif must_not_use_token_auth and not can_use_basic_auth: + raise ValueError('If use_token_auth is False you must provide a key') + + # Using token auth + self.__auth_mechanism = Auth.Method.TOKEN + + if options.token_details: + self.__token_details = options.token_details + elif options.auth_token: + self.__token_details = TokenDetails(token=options.auth_token) + else: + self.__token_details = None + + if options.auth_callback: + log.debug("using token auth with auth_callback") + elif options.auth_url: + log.debug("using token auth with auth_url") + elif options.key_secret: + log.debug("using token auth with client-side signing") + elif options.auth_token: + log.debug("using token auth with supplied token only") + elif options.token_details: + log.debug("using token auth with supplied token_details") + else: + raise ValueError("Can't authenticate via token, must provide " + "auth_callback, auth_url, key, token or a TokenDetail") + + def get_auth_transport_param(self): + auth_credentials = {} + if self.auth_options.client_id: + auth_credentials["client_id"] = self.auth_options.client_id + if self.__auth_mechanism == Auth.Method.BASIC: + key_name = self.__auth_options.key_name + key_secret = self.__auth_options.key_secret + auth_credentials["key"] = f"{key_name}:{key_secret}" + elif self.__auth_mechanism == Auth.Method.TOKEN: + token_details = self._ensure_valid_auth_credentials() + auth_credentials["accessToken"] = token_details.token + return auth_credentials + + def __authorize_when_necessary(self, token_params=None, auth_options=None, force=False): + token_details = self._ensure_valid_auth_credentials(token_params, auth_options, force) + + if self.ably._is_realtime: + self.ably.connection.connection_manager.on_auth_updated(token_details) + + return token_details + + def _ensure_valid_auth_credentials(self, token_params=None, auth_options=None, force=False): + self.__auth_mechanism = Auth.Method.TOKEN + if token_params is None: + token_params = dict(self.auth_options.default_token_params) + else: + self.auth_options.default_token_params = dict(token_params) + self.auth_options.default_token_params.pop('timestamp', None) + + if auth_options is not None: + self.auth_options.replace(auth_options) + auth_options = dict(self.auth_options.auth_options) + if self.client_id is not None: + token_params['client_id'] = self.client_id + + token_details = self.__token_details + if not force and not self.token_details_has_expired(): + log.debug("using cached token; expires = %d", + token_details.expires) + return token_details + + self.__token_details = self.request_token(token_params, **auth_options) + self._configure_client_id(self.__token_details.client_id) + + return self.__token_details + + def token_details_has_expired(self): + token_details = self.__token_details + if token_details is None: + return True + + if not self.__time_offset: + return False + + expires = token_details.expires + if expires is None: + return False + + timestamp = self._timestamp() + if self.__time_offset: + timestamp += self.__time_offset + + return expires < timestamp + token_details.TOKEN_EXPIRY_BUFFER + + def authorize(self, token_params: Optional[dict] = None, auth_options=None): + return self.__authorize_when_necessary(token_params, auth_options, force=True) + + def request_token(self, token_params: Optional[dict] = None, + # auth_options + key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None, + auth_url: Optional[str] = None, auth_method: Optional[str] = None, + auth_headers: Optional[dict] = None, auth_params: Optional[dict] = None, + query_time=None): + token_params = token_params or {} + token_params = dict(self.auth_options.default_token_params, + **token_params) + key_name = key_name or self.auth_options.key_name + key_secret = key_secret or self.auth_options.key_secret + + log.debug("Auth callback: %s" % auth_callback) + log.debug("Auth options: %s" % self.auth_options) + if query_time is None: + query_time = self.auth_options.query_time + query_time = bool(query_time) + auth_callback = auth_callback or self.auth_options.auth_callback + auth_url = auth_url or self.auth_options.auth_url + + auth_params = auth_params or self.auth_options.auth_params or {} + + auth_method = (auth_method or self.auth_options.auth_method).upper() + + auth_headers = auth_headers or self.auth_options.auth_headers or {} + + log.debug("Token Params: %s" % token_params) + if auth_callback: + log.debug("using token auth with authCallback") + try: + token_request = auth_callback(token_params) + except Exception as e: + raise AblyException("auth_callback raised an exception", 401, 40170, cause=e) + elif auth_url: + log.debug("using token auth with authUrl") + + token_request = self.token_request_from_auth_url( + auth_method, auth_url, token_params, auth_headers, auth_params) + elif key_name is not None and key_secret is not None: + token_request = self.create_token_request( + token_params, key_name=key_name, key_secret=key_secret, + query_time=query_time) + else: + msg = "Need a new token but auth_options does not include a way to request one" + log.exception(msg) + raise AblyAuthException(msg, 403, 40171) + if isinstance(token_request, TokenDetails): + return token_request + elif isinstance(token_request, dict) and 'issued' in token_request: + return TokenDetails.from_dict(token_request) + elif isinstance(token_request, dict): + try: + token_request = TokenRequest.from_json(token_request) + except TypeError as e: + msg = "Expected token request callback to call back with a token string, token request object, or \ + token details object" + raise AblyAuthException(msg, 401, 40170, cause=e) + elif isinstance(token_request, str): + if len(token_request) == 0: + raise AblyAuthException("Token string is empty", 401, 4017) + return TokenDetails(token=token_request) + elif token_request is None: + raise AblyAuthException("Token string was None", 401, 40170) + + token_path = "/keys/%s/requestToken" % token_request.key_name + + response = self.ably.http.post( + token_path, + headers=auth_headers, + body=token_request.to_dict(), + skip_auth=True + ) + + AblyException.raise_for_response(response) + response_dict = response.to_native() + log.debug("Token: %s" % str(response_dict.get("token"))) + return TokenDetails.from_dict(response_dict) + + def create_token_request(self, token_params: Optional[dict] = None, key_name: Optional[str] = None, + key_secret: Optional[str] = None, query_time=None): + token_params = token_params or {} + token_request = {} + + key_name = key_name or self.auth_options.key_name + key_secret = key_secret or self.auth_options.key_secret + if not key_name or not key_secret: + log.debug('key_name or key_secret blank') + raise AblyException("No key specified: no means to generate a token", 401, 40101) + + token_request['key_name'] = key_name + if token_params.get('timestamp'): + token_request['timestamp'] = token_params['timestamp'] + else: + if query_time is None: + query_time = self.auth_options.query_time + + if query_time: + if self.__time_offset is None: + server_time = self.ably.time() + local_time = self._timestamp() + self.__time_offset = server_time - local_time + token_request['timestamp'] = server_time + else: + local_time = self._timestamp() + token_request['timestamp'] = local_time + self.__time_offset + else: + token_request['timestamp'] = self._timestamp() + + token_request['timestamp'] = int(token_request['timestamp']) + + ttl = token_params.get('ttl') + if ttl is not None: + if isinstance(ttl, timedelta): + ttl = ttl.total_seconds() * 1000 + token_request['ttl'] = int(ttl) + + capability = token_params.get('capability') + if capability is not None: + token_request['capability'] = str(Capability(capability)) + + token_request["client_id"] = ( + token_params.get('client_id') or self.client_id) + + # Note: There is no expectation that the client + # specifies the nonce; this is done by the library + # However, this can be overridden by the client + # simply for testing purposes + token_request["nonce"] = token_params.get('nonce') or self._random_nonce() + + token_req = TokenRequest(**token_request) + + if token_params.get('mac') is None: + # Note: There is no expectation that the client + # specifies the mac; this is done by the library + # However, this can be overridden by the client + # simply for testing purposes. + token_req.sign_request(key_secret.encode('utf8')) + else: + token_req.mac = token_params['mac'] + + return token_req + + @property + def ably(self): + return self.__ably + + @property + def auth_mechanism(self): + return self.__auth_mechanism + + @property + def auth_options(self): + return self.__auth_options + + @property + def auth_params(self): + return self.__auth_params + + @property + def basic_credentials(self): + return self.__basic_credentials + + @property + def token_credentials(self): + if self.__token_details: + token = self.__token_details.token + token_key = base64.b64encode(token.encode('utf-8')) + return token_key.decode('ascii') + + @property + def token_details(self): + return self.__token_details + + @property + def client_id(self): + return self.__client_id + + @property + def time_offset(self): + return self.__time_offset + + def _configure_client_id(self, new_client_id): + log.debug("Auth._configure_client_id(): new client_id = %s", new_client_id) + original_client_id = self.client_id or self.auth_options.client_id + + # If new client ID from Ably is a wildcard, but preconfigured clientId is set, + # then keep the existing clientId + if original_client_id != '*' and new_client_id == '*': + self.__client_id_validated = True + self.__client_id = original_client_id + return + + # If client_id is defined and not a wildcard, prevent it changing, this is not supported + if original_client_id is not None and original_client_id != '*' and new_client_id != original_client_id: + raise IncompatibleClientIdException( + "Client ID is immutable once configured for a client. " + "Client ID cannot be changed to '{}'".format(new_client_id), 400, 40102) + + self.__client_id_validated = True + self.__client_id = new_client_id + + def can_assume_client_id(self, assumed_client_id): + original_client_id = self.client_id or self.auth_options.client_id + + if self.__client_id_validated: + return self.client_id == '*' or self.client_id == assumed_client_id + elif original_client_id is None or original_client_id == '*': + return True # client ID is unknown + else: + return original_client_id == assumed_client_id + + def _get_auth_headers(self): + if self.__auth_mechanism == Auth.Method.BASIC: + # RSA7e2 + if self.client_id: + return { + 'Authorization': 'Basic %s' % self.basic_credentials, + 'X-Ably-ClientId': base64.b64encode(self.client_id.encode('utf-8')) + } + return { + 'Authorization': 'Basic %s' % self.basic_credentials, + } + else: + self.__authorize_when_necessary() + return { + 'Authorization': 'Bearer %s' % self.token_credentials, + } + + def _timestamp(self): + """Returns the local time in milliseconds since the unix epoch""" + return int(time.time() * 1000) + + def _random_nonce(self): + return uuid.uuid4().hex[:16] + + def token_request_from_auth_url(self, method: str, url: str, token_params, + headers, auth_params): + body = None + params = None + if method == 'GET': + body = {} + params = dict(auth_params, **token_params) + elif method == 'POST': + if isinstance(auth_params, TokenDetails): + auth_params = auth_params.to_dict() + params = {} + body = dict(auth_params, **token_params) + + from ably.sync.http.http import Response + with httpx.Client(http2=True) as client: + resp = client.request(method=method, url=url, headers=headers, params=params, data=body) + response = Response(resp) + + AblyException.raise_for_response(response) + + content_type = response.response.headers.get('content-type') + + if not content_type: + raise AblyAuthException("auth_url response missing a content-type header", 401, 40170) + + is_json = "application/json" in content_type + is_text = "application/jwt" in content_type or "text/plain" in content_type + + if is_json: + token_request = response.to_native() + elif is_text: + token_request = response.text + else: + msg = 'auth_url responded with unacceptable content-type ' + content_type + \ + ', should be either text/plain, application/jwt or application/json', + raise AblyAuthException(msg, 401, 40170) + return token_request diff --git a/ably/sync/rest/channel.py b/ably/sync/rest/channel.py new file mode 100644 index 00000000..f1f3f199 --- /dev/null +++ b/ably/sync/rest/channel.py @@ -0,0 +1,229 @@ +import base64 +from collections import OrderedDict +import logging +import json +import os +from typing import Iterator +from urllib import parse + +from methoddispatch import SingleDispatch, singledispatch +import msgpack + +from ably.sync.http.paginatedresult import PaginatedResult, format_params +from ably.sync.types.channeldetails import ChannelDetails +from ably.sync.types.message import Message, make_message_response_handler +from ably.sync.types.presence import Presence +from ably.sync.util.crypto import get_cipher +from ably.sync.util.exceptions import catch_all, IncompatibleClientIdException + +log = logging.getLogger(__name__) + + +class Channel(SingleDispatch): + def __init__(self, ably, name, options): + self.__ably = ably + self.__name = name + self.__base_path = '/channels/%s/' % parse.quote_plus(name, safe=':') + self.__cipher = None + self.options = options + self.__presence = Presence(self) + + @catch_all + def history(self, direction=None, limit: int = None, start=None, end=None): + """Returns the history for this channel""" + params = format_params({}, direction=direction, start=start, end=end, limit=limit) + path = self.__base_path + 'messages' + params + + message_handler = make_message_response_handler(self.__cipher) + return PaginatedResult.paginated_query( + self.ably.http, url=path, response_processor=message_handler) + + def __publish_request_body(self, messages): + """ + Helper private method, separated from publish() to test RSL1j + """ + # Idempotent publishing + if self.ably.options.idempotent_rest_publishing: + # RSL1k1 + if all(message.id is None for message in messages): + base_id = base64.b64encode(os.urandom(12)).decode() + for serial, message in enumerate(messages): + message.id = '{}:{}'.format(base_id, serial) + + request_body_list = [] + for m in messages: + if m.client_id == '*': + raise IncompatibleClientIdException( + 'Wildcard client_id is reserved and cannot be used when publishing messages', + 400, 40012) + elif m.client_id is not None and not self.ably.auth.can_assume_client_id(m.client_id): + raise IncompatibleClientIdException( + 'Cannot publish with client_id \'{}\' as it is incompatible with the ' + 'current configured client_id \'{}\''.format(m.client_id, self.ably.auth.client_id), + 400, 40012) + + if self.cipher: + m.encrypt(self.__cipher) + + request_body_list.append(m) + + request_body = [ + message.as_dict(binary=self.ably.options.use_binary_protocol) + for message in request_body_list] + + if len(request_body) == 1: + request_body = request_body[0] + + return request_body + + @singledispatch + def _publish(self, arg, *args, **kwargs): + raise TypeError('Unexpected type %s' % type(arg)) + + @_publish.register(Message) + def publish_message(self, message, params=None, timeout=None): + return self.publish_messages([message], params, timeout=timeout) + + @_publish.register(list) + def publish_messages(self, messages, params=None, timeout=None): + request_body = self.__publish_request_body(messages) + if not self.ably.options.use_binary_protocol: + request_body = json.dumps(request_body, separators=(',', ':')) + else: + request_body = msgpack.packb(request_body, use_bin_type=True) + + path = self.__base_path + 'messages' + if params: + params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} + path += '?' + parse.urlencode(params) + return self.ably.http.post(path, body=request_body, timeout=timeout) + + @_publish.register(str) + def publish_name_data(self, name, data, timeout=None): + messages = [Message(name, data)] + return self.publish_messages(messages, timeout=timeout) + + def publish(self, *args, **kwargs): + """Publishes a message on this channel. + + :Parameters: + - `name`: the name for this message. + - `data`: the data for this message. + - `messages`: list of `Message` objects to be published. + - `message`: a single `Message` objet to be published + + :attention: You can publish using `name` and `data` OR `messages` OR + `message`, never all three. + """ + # For backwards compatibility + if len(args) == 0: + if len(kwargs) == 0: + return self.publish_name_data(None, None) + + if 'name' in kwargs or 'data' in kwargs: + name = kwargs.pop('name', None) + data = kwargs.pop('data', None) + return self.publish_name_data(name, data, **kwargs) + + if 'messages' in kwargs: + messages = kwargs.pop('messages') + return self.publish_messages(messages, **kwargs) + + return self._publish(*args, **kwargs) + + def status(self): + """Retrieves current channel active status with no. of publishers, subscribers, presence_members etc""" + + path = '/channels/%s' % self.name + response = self.ably.http.get(path) + obj = response.to_native() + return ChannelDetails.from_dict(obj) + + @property + def ably(self): + return self.__ably + + @property + def name(self): + return self.__name + + @property + def base_path(self): + return self.__base_path + + @property + def cipher(self): + return self.__cipher + + @property + def options(self): + return self.__options + + @property + def presence(self): + return self.__presence + + @options.setter + def options(self, options): + self.__options = options + + if options and 'cipher' in options: + cipher = options.get('cipher') + if cipher is not None: + cipher = get_cipher(cipher) + self.__cipher = cipher + + +class Channels: + def __init__(self, rest): + self.__ably = rest + self.__all: dict = OrderedDict() + + def get(self, name, **kwargs): + if isinstance(name, bytes): + name = name.decode('ascii') + + if name not in self.__all: + result = self.__all[name] = Channel(self.__ably, name, kwargs) + else: + result = self.__all[name] + if len(kwargs) != 0: + result.options = kwargs + + return result + + def __getitem__(self, key): + return self.get(key) + + def __getattr__(self, name): + return self.get(name) + + def __contains__(self, item): + if isinstance(item, Channel): + name = item.name + elif isinstance(item, bytes): + name = item.decode('ascii') + else: + name = item + + return name in self.__all + + def __iter__(self) -> Iterator[str]: + return iter(self.__all.values()) + + # RSN4 + def release(self, name: str): + """Releases a Channel object, deleting it, and enabling it to be garbage collected. + If the channel does not exist, nothing happens. + + It also removes any listeners associated with the channel. + + Parameters + ---------- + name: str + Channel name + """ + + if name not in self.__all: + return + del self.__all[name] diff --git a/ably/sync/rest/push.py b/ably/sync/rest/push.py new file mode 100644 index 00000000..fabb2c1a --- /dev/null +++ b/ably/sync/rest/push.py @@ -0,0 +1,189 @@ +from typing import Optional +from ably.sync.http.paginatedresult import PaginatedResult, format_params +from ably.sync.types.device import DeviceDetails, device_details_response_processor +from ably.sync.types.channelsubscription import PushChannelSubscription, channel_subscriptions_response_processor +from ably.sync.types.channelsubscription import channels_response_processor + + +class Push: + + def __init__(self, ably): + self.__ably = ably + self.__admin = PushAdmin(ably) + + @property + def admin(self): + return self.__admin + + +class PushAdmin: + + def __init__(self, ably): + self.__ably = ably + self.__device_registrations = PushDeviceRegistrations(ably) + self.__channel_subscriptions = PushChannelSubscriptions(ably) + + @property + def ably(self): + return self.__ably + + @property + def device_registrations(self): + return self.__device_registrations + + @property + def channel_subscriptions(self): + return self.__channel_subscriptions + + def publish(self, recipient: dict, data: dict, timeout: Optional[float] = None): + """Publish a push notification to a single device. + + :Parameters: + - `recipient`: the recipient of the notification + - `data`: the data of the notification + """ + if not isinstance(recipient, dict): + raise TypeError('Unexpected %s recipient, expected a dict' % type(recipient)) + + if not isinstance(data, dict): + raise TypeError('Unexpected %s data, expected a dict' % type(data)) + + if not recipient: + raise ValueError('recipient is empty') + + if not data: + raise ValueError('data is empty') + + body = data.copy() + body.update({'recipient': recipient}) + self.ably.http.post('/push/publish', body=body, timeout=timeout) + + +class PushDeviceRegistrations: + + def __init__(self, ably): + self.__ably = ably + + @property + def ably(self): + return self.__ably + + def get(self, device_id: str): + """Returns a DeviceDetails object if the device id is found or results + in a not found error if the device cannot be found. + + :Parameters: + - `device_id`: the id of the device + """ + path = '/push/deviceRegistrations/%s' % device_id + response = self.ably.http.get(path) + obj = response.to_native() + return DeviceDetails.from_dict(obj) + + def list(self, **params): + """Returns a PaginatedResult object with the list of DeviceDetails + objects, filtered by the given parameters. + + :Parameters: + - `**params`: the parameters used to filter the list + """ + path = '/push/deviceRegistrations' + format_params(params) + return PaginatedResult.paginated_query( + self.ably.http, url=path, + response_processor=device_details_response_processor) + + def save(self, device: dict): + """Creates or updates the device. Returns a DeviceDetails object. + + :Parameters: + - `device`: a dictionary with the device information + """ + device_details = DeviceDetails.factory(device) + path = '/push/deviceRegistrations/%s' % device_details.id + body = device_details.as_dict() + response = self.ably.http.put(path, body=body) + obj = response.to_native() + return DeviceDetails.from_dict(obj) + + def remove(self, device_id: str): + """Deletes the registered device identified by the given device id. + + :Parameters: + - `device_id`: the id of the device + """ + path = '/push/deviceRegistrations/%s' % device_id + return self.ably.http.delete(path) + + def remove_where(self, **params): + """Deletes the registered devices identified by the given parameters. + + :Parameters: + - `**params`: the parameters that identify the devices to remove + """ + path = '/push/deviceRegistrations' + format_params(params) + return self.ably.http.delete(path) + + +class PushChannelSubscriptions: + + def __init__(self, ably): + self.__ably = ably + + @property + def ably(self): + return self.__ably + + def list(self, **params): + """Returns a PaginatedResult object with the list of + PushChannelSubscription objects, filtered by the given parameters. + + :Parameters: + - `**params`: the parameters used to filter the list + """ + path = '/push/channelSubscriptions' + format_params(params) + return PaginatedResult.paginated_query(self.ably.http, url=path, + response_processor=channel_subscriptions_response_processor) + + def list_channels(self, **params): + """Returns a PaginatedResult object with the list of + PushChannelSubscription objects, filtered by the given parameters. + + :Parameters: + - `**params`: the parameters used to filter the list + """ + path = '/push/channels' + format_params(params) + return PaginatedResult.paginated_query(self.ably.http, url=path, + response_processor=channels_response_processor) + + def save(self, subscription: dict): + """Creates or updates the subscription. Returns a + PushChannelSubscription object. + + :Parameters: + - `subscription`: a dictionary with the subscription information + """ + subscription = PushChannelSubscription.factory(subscription) + path = '/push/channelSubscriptions' + body = subscription.as_dict() + response = self.ably.http.post(path, body=body) + obj = response.to_native() + return PushChannelSubscription.from_dict(obj) + + def remove(self, subscription: dict): + """Deletes the given subscription. + + :Parameters: + - `subscription`: the subscription object to remove + """ + subscription = PushChannelSubscription.factory(subscription) + params = subscription.as_dict() + return self.remove_where(**params) + + def remove_where(self, **params): + """Deletes the subscriptions identified by the given parameters. + + :Parameters: + - `**params`: the parameters that identify the subscriptions to remove + """ + path = '/push/channelSubscriptions' + format_params(**params) + return self.ably.http.delete(path) diff --git a/ably/sync/rest/rest.py b/ably/sync/rest/rest.py new file mode 100644 index 00000000..ff163967 --- /dev/null +++ b/ably/sync/rest/rest.py @@ -0,0 +1,148 @@ +import logging +from typing import Optional +from urllib.parse import urlencode + +from ably.sync.http.http import Http +from ably.sync.http.paginatedresult import PaginatedResult, HttpPaginatedResponse +from ably.sync.http.paginatedresult import format_params +from ably.sync.rest.auth import Auth +from ably.sync.rest.channel import Channels +from ably.sync.rest.push import Push +from ably.sync.util.exceptions import AblyException, catch_all +from ably.sync.types.options import Options +from ably.sync.types.stats import stats_response_processor +from ably.sync.types.tokendetails import TokenDetails + +log = logging.getLogger(__name__) + + +class AblyRest: + """Ably Rest Client""" + + def __init__(self, key: Optional[str] = None, token: Optional[str] = None, + token_details: Optional[TokenDetails] = None, **kwargs): + """Create an AblyRest instance. + + :Parameters: + **Credentials** + - `key`: a valid key string + + **Or** + - `token`: a valid token string + - `token_details`: an instance of TokenDetails class + + **Optional Parameters** + - `client_id`: Undocumented + - `rest_host`: The host to connect to. Defaults to rest.ably.io + - `environment`: The environment to use. Defaults to 'production' + - `port`: The port to connect to. Defaults to 80 + - `tls_port`: The tls_port to connect to. Defaults to 443 + - `tls`: Specifies whether the client should use TLS. Defaults + to True + - `auth_token`: Undocumented + - `auth_callback`: Undocumented + - `auth_url`: Undocumented + - `keep_alive`: use persistent connections. Defaults to True + """ + if key is not None and ('key_name' in kwargs or 'key_secret' in kwargs): + raise ValueError("key and key_name or key_secret are mutually exclusive. " + "Provider either a key or key_name & key_secret") + if key is not None: + options = Options(key=key, **kwargs) + elif token is not None: + options = Options(auth_token=token, **kwargs) + elif token_details is not None: + if not isinstance(token_details, TokenDetails): + raise ValueError("token_details must be an instance of TokenDetails") + options = Options(token_details=token_details, **kwargs) + elif not ('auth_callback' in kwargs or 'auth_url' in kwargs or + # and don't have both key_name and key_secret + ('key_name' in kwargs and 'key_secret' in kwargs)): + raise ValueError("key is missing. Either an API key, token, or token auth method must be provided") + else: + options = Options(**kwargs) + + try: + self._is_realtime + except AttributeError: + self._is_realtime = False + + self.__http = Http(self, options) + self.__auth = Auth(self, options) + self.__http.auth = self.__auth + + self.__channels = Channels(self) + self.__options = options + self.__push = Push(self) + + def __enter__(self): + return self + + @catch_all + def stats(self, direction: Optional[str] = None, start=None, end=None, params: Optional[dict] = None, + limit: Optional[int] = None, paginated=None, unit=None, timeout=None): + """Returns the stats for this application""" + formatted_params = format_params(params, direction=direction, start=start, end=end, limit=limit, unit=unit) + url = '/stats' + formatted_params + return PaginatedResult.paginated_query( + self.http, url=url, response_processor=stats_response_processor) + + @catch_all + def time(self, timeout: Optional[float] = None) -> float: + """Returns the current server time in ms since the unix epoch""" + r = self.http.get('/time', skip_auth=True, timeout=timeout) + AblyException.raise_for_response(r) + return r.to_native()[0] + + @property + def client_id(self) -> Optional[str]: + return self.options.client_id + + @property + def channels(self): + """Returns the channels container object""" + return self.__channels + + @property + def auth(self): + return self.__auth + + @property + def http(self): + return self.__http + + @property + def options(self): + return self.__options + + @property + def push(self): + return self.__push + + def request(self, method: str, path: str, version: str, params: + Optional[dict] = None, body=None, headers=None): + if version is None: + raise AblyException("No version parameter", 400, 40000) + + url = path + if params: + url += '?' + urlencode(params) + + def response_processor(response): + items = response.to_native() + if not items: + return [] + if type(items) is not list: + items = [items] + return items + + return HttpPaginatedResponse.paginated_query( + self.http, method, url, version=version, body=body, headers=headers, + response_processor=response_processor, + raise_on_error=False) + + def __exit__(self, *excinfo): + self.close() + + def close(self): + self.http.close() diff --git a/ably/sync/transport/__init__.py b/ably/sync/transport/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ably/sync/transport/defaults.py b/ably/sync/transport/defaults.py new file mode 100644 index 00000000..7a732d9a --- /dev/null +++ b/ably/sync/transport/defaults.py @@ -0,0 +1,63 @@ +class Defaults: + protocol_version = "2" + fallback_hosts = [ + "a.ably-realtime.com", + "b.ably-realtime.com", + "c.ably-realtime.com", + "d.ably-realtime.com", + "e.ably-realtime.com", + ] + + rest_host = "rest.ably.io" + realtime_host = "realtime.ably.io" # RTN2 + connectivity_check_url = "https://internet-up.ably-realtime.com/is-the-internet-up.txt" + environment = 'production' + + port = 80 + tls_port = 443 + connect_timeout = 15000 + disconnect_timeout = 10000 + suspended_timeout = 60000 + comet_recv_timeout = 90000 + comet_send_timeout = 10000 + realtime_request_timeout = 10000 + channel_retry_timeout = 15000 + disconnected_retry_timeout = 15000 + connection_state_ttl = 120000 + suspended_retry_timeout = 30000 + + transports = [] # ["web_socket", "comet"] + + http_max_retry_count = 3 + + fallback_retry_timeout = 600000 # 10min + + @staticmethod + def get_port(options): + if options.tls: + if options.tls_port: + return options.tls_port + else: + return Defaults.tls_port + else: + if options.port: + return options.port + else: + return Defaults.port + + @staticmethod + def get_scheme(options): + if options.tls: + return "https" + else: + return "http" + + @staticmethod + def get_environment_fallback_hosts(environment): + return [ + environment + "-a-fallback.ably-realtime.com", + environment + "-b-fallback.ably-realtime.com", + environment + "-c-fallback.ably-realtime.com", + environment + "-d-fallback.ably-realtime.com", + environment + "-e-fallback.ably-realtime.com", + ] diff --git a/ably/sync/transport/websockettransport.py b/ably/sync/transport/websockettransport.py new file mode 100644 index 00000000..2de820d3 --- /dev/null +++ b/ably/sync/transport/websockettransport.py @@ -0,0 +1,219 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +import asyncio +from enum import IntEnum +import json +import logging +import socket +import urllib.parse +from ably.sync.http.httputils import HttpUtils +from ably.sync.types.connectiondetails import ConnectionDetails +from ably.sync.util.eventemitter import EventEmitter +from ably.sync.util.exceptions import AblyException +from ably.sync.util.helper import Timer, unix_time_ms +from websockets.client import WebSocketClientProtocol, connect as ws_connect +from websockets.exceptions import ConnectionClosedOK, WebSocketException + +if TYPE_CHECKING: + from ably.sync.realtime.connection import ConnectionManager + +log = logging.getLogger(__name__) + + +class ProtocolMessageAction(IntEnum): + HEARTBEAT = 0 + CONNECTED = 4 + DISCONNECTED = 6 + CLOSE = 7 + CLOSED = 8 + ERROR = 9 + ATTACH = 10 + ATTACHED = 11 + DETACH = 12 + DETACHED = 13 + MESSAGE = 15 + AUTH = 17 + + +class WebSocketTransport(EventEmitter): + def __init__(self, connection_manager: ConnectionManager, host: str, params: dict): + self.websocket: WebSocketClientProtocol | None = None + self.read_loop: asyncio.Task | None = None + self.connect_task: asyncio.Task | None = None + self.ws_connect_task: asyncio.Task | None = None + self.connection_manager = connection_manager + self.options = self.connection_manager.options + self.is_connected = False + self.idle_timer = None + self.last_activity = None + self.max_idle_interval = None + self.is_disposed = False + self.host = host + self.params = params + super().__init__() + + def connect(self): + headers = HttpUtils.default_headers() + query_params = urllib.parse.urlencode(self.params) + ws_url = (f'wss://{self.host}?{query_params}') + log.info(f'connect(): attempting to connect to {ws_url}') + self.ws_connect_task = asyncio.create_task(self.ws_connect(ws_url, headers)) + self.ws_connect_task.add_done_callback(self.on_ws_connect_done) + + def on_ws_connect_done(self, task: asyncio.Task): + try: + exception = task.exception() + except asyncio.CancelledError as e: + exception = e + if exception is None or isinstance(exception, ConnectionClosedOK): + return + log.info( + f'WebSocketTransport.on_ws_connect_done(): exception = {exception}' + ) + + def ws_connect(self, ws_url, headers): + try: + with ws_connect(ws_url, extra_headers=headers) as websocket: + log.info(f'ws_connect(): connection established to {ws_url}') + self._emit('connected') + self.websocket = websocket + self.read_loop = self.connection_manager.options.loop.create_task(self.ws_read_loop()) + self.read_loop.add_done_callback(self.on_read_loop_done) + try: + self.read_loop + except WebSocketException as err: + if not self.is_disposed: + self.dispose() + self.connection_manager.deactivate_transport(err) + except (WebSocketException, socket.gaierror) as e: + exception = AblyException(f'Error opening websocket connection: {e}', 400, 40000) + log.exception(f'WebSocketTransport.ws_connect(): Error opening websocket connection: {exception}') + self._emit('failed', exception) + raise exception + + def on_protocol_message(self, msg): + self.on_activity() + log.debug(f'WebSocketTransport.on_protocol_message(): received protocol message: {msg}') + action = msg.get('action') + if action == ProtocolMessageAction.CONNECTED: + connection_id = msg.get('connectionId') + connection_details = ConnectionDetails.from_dict(msg.get('connectionDetails')) + + error = msg.get('error') + exception = None + if error: + exception = AblyException.from_dict(error) + + max_idle_interval = connection_details.max_idle_interval + if max_idle_interval: + self.max_idle_interval = max_idle_interval + self.options.realtime_request_timeout + self.on_activity() + self.is_connected = True + if self.host != self.options.get_realtime_host(): # RTN17e + self.options.fallback_realtime_host = self.host + self.connection_manager.on_connected(connection_details, connection_id, reason=exception) + elif action == ProtocolMessageAction.DISCONNECTED: + error = msg.get('error') + exception = None + if error is not None: + exception = AblyException.from_dict(error) + self.connection_manager.on_disconnected(exception) + elif action == ProtocolMessageAction.AUTH: + try: + self.connection_manager.ably.auth.authorize() + except Exception as exc: + log.exception(f"WebSocketTransport.on_protocol_message(): An exception \ + occurred during reauth: {exc}") + elif action == ProtocolMessageAction.CLOSED: + if self.ws_connect_task: + self.ws_connect_task.cancel() + self.connection_manager.on_closed() + elif action == ProtocolMessageAction.ERROR: + error = msg.get('error') + exception = AblyException.from_dict(error) + self.connection_manager.on_error(msg, exception) + elif action == ProtocolMessageAction.HEARTBEAT: + id = msg.get('id') + self.connection_manager.on_heartbeat(id) + elif action in ( + ProtocolMessageAction.ATTACHED, + ProtocolMessageAction.DETACHED, + ProtocolMessageAction.MESSAGE + ): + self.connection_manager.on_channel_message(msg) + + def ws_read_loop(self): + if not self.websocket: + raise AblyException('ws_read_loop started with no websocket', 500, 50000) + try: + for raw in self.websocket: + msg = json.loads(raw) + task = asyncio.create_task(self.on_protocol_message(msg)) + task.add_done_callback(self.on_protcol_message_handled) + except ConnectionClosedOK: + return + + def on_protcol_message_handled(self, task): + try: + exception = task.exception() + except Exception as e: + exception = e + if exception is not None: + log.exception(f"WebSocketTransport.on_protocol_message_handled(): uncaught exception: {exception}") + + def on_read_loop_done(self, task: asyncio.Task): + try: + exception = task.exception() + except asyncio.CancelledError as e: + exception = e + if isinstance(exception, ConnectionClosedOK): + return + + def dispose(self): + self.is_disposed = True + if self.read_loop: + self.read_loop.cancel() + if self.ws_connect_task: + self.ws_connect_task.cancel() + if self.idle_timer: + self.idle_timer.cancel() + if self.websocket: + try: + self.websocket.close() + except asyncio.CancelledError: + return + + def close(self): + self.send({'action': ProtocolMessageAction.CLOSE}) + + def send(self, message: dict): + if self.websocket is None: + raise Exception() + raw_msg = json.dumps(message) + log.info(f'WebSocketTransport.send(): sending {raw_msg}') + self.websocket.send(raw_msg) + + def set_idle_timer(self, timeout: float): + if not self.idle_timer: + self.idle_timer = Timer(timeout, self.on_idle_timer_expire) + + def on_idle_timer_expire(self): + self.idle_timer = None + since_last = unix_time_ms() - self.last_activity + time_remaining = self.max_idle_interval - since_last + msg = f"No activity seen from realtime in {since_last} ms; assuming connection has dropped" + if time_remaining <= 0: + log.error(msg) + self.disconnect(AblyException(msg, 408, 80003)) + else: + self.set_idle_timer(time_remaining + 100) + + def on_activity(self): + if not self.max_idle_interval: + return + self.last_activity = unix_time_ms() + self.set_idle_timer(self.max_idle_interval + 100) + + def disconnect(self, reason=None): + self.dispose() + self.connection_manager.deactivate_transport(reason) diff --git a/ably/sync/types/__init__.py b/ably/sync/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ably/sync/types/authoptions.py b/ably/sync/types/authoptions.py new file mode 100644 index 00000000..77178f47 --- /dev/null +++ b/ably/sync/types/authoptions.py @@ -0,0 +1,157 @@ +from ably.sync.util.exceptions import AblyException + + +class AuthOptions: + def __init__(self, auth_callback=None, auth_url=None, auth_method='GET', + auth_token=None, auth_headers=None, auth_params=None, + key_name=None, key_secret=None, key=None, query_time=False, + token_details=None, use_token_auth=None, + default_token_params=None): + self.__auth_options = {} + self.auth_options['auth_callback'] = auth_callback + self.auth_options['auth_url'] = auth_url + self.auth_options['auth_method'] = auth_method + self.auth_options['auth_headers'] = auth_headers + self.auth_options['auth_params'] = auth_params + self.auth_options['query_time'] = query_time + self.auth_options['key_name'] = key_name + self.auth_options['key_secret'] = key_secret + self.set_key(key) + + self.__auth_token = auth_token + self.__token_details = token_details + self.__use_token_auth = use_token_auth + default_token_params = default_token_params or {} + default_token_params.pop('timestamp', None) + self.default_token_params = default_token_params + + def set_key(self, key): + if key is None: + return + + try: + key_name, key_secret = key.split(':') + self.auth_options['key_name'] = key_name + self.auth_options['key_secret'] = key_secret + except ValueError: + raise AblyException("key of not len 2 parameters: {0}" + .format(key.split(':')), + 401, 40101) + + def replace(self, auth_options): + if type(auth_options) is dict: + auth_options = dict(auth_options) + key = auth_options.pop('key', None) + self.auth_options = auth_options + self.set_key(key) + elif type(auth_options) is AuthOptions: + self.auth_options = dict(auth_options.auth_options) + else: + raise KeyError('Expected dict or AuthOptions') + + @property + def auth_options(self): + return self.__auth_options + + @auth_options.setter + def auth_options(self, value): + self.__auth_options = value + + @property + def auth_callback(self): + return self.auth_options['auth_callback'] + + @auth_callback.setter + def auth_callback(self, value): + self.auth_options['auth_callback'] = value + + @property + def auth_url(self): + return self.auth_options['auth_url'] + + @auth_url.setter + def auth_url(self, value): + self.auth_options['auth_url'] = value + + @property + def auth_method(self): + return self.auth_options['auth_method'] + + @auth_method.setter + def auth_method(self, value): + self.auth_options['auth_method'] = value.upper() + + @property + def key_name(self): + return self.auth_options['key_name'] + + @key_name.setter + def key_name(self, value): + self.auth_options['key_name'] = value + + @property + def key_secret(self): + return self.auth_options['key_secret'] + + @key_secret.setter + def key_secret(self, value): + self.auth_options['key_secret'] = value + + @property + def auth_token(self): + return self.__auth_token + + @auth_token.setter + def auth_token(self, value): + self.__auth_token = value + + @property + def auth_headers(self): + return self.auth_options['auth_headers'] + + @auth_headers.setter + def auth_headers(self, value): + self.auth_options['auth_headers'] = value + + @property + def auth_params(self): + return self.auth_options['auth_params'] + + @auth_params.setter + def auth_params(self, value): + self.auth_options['auth_params'] = value + + @property + def query_time(self): + return self.auth_options['query_time'] + + @query_time.setter + def query_time(self, value): + self.auth_options['query_time'] = value + + @property + def token_details(self): + return self.__token_details + + @token_details.setter + def token_details(self, value): + self.__token_details = value + + @property + def use_token_auth(self): + return self.__use_token_auth + + @use_token_auth.setter + def use_token_auth(self, value): + self.__use_token_auth = value + + @property + def default_token_params(self): + return self.__default_token_params + + @default_token_params.setter + def default_token_params(self, value): + self.__default_token_params = value + + def __str__(self): + return str(self.__dict__) diff --git a/ably/sync/types/capability.py b/ably/sync/types/capability.py new file mode 100644 index 00000000..5d209d7c --- /dev/null +++ b/ably/sync/types/capability.py @@ -0,0 +1,82 @@ +from collections.abc import MutableMapping +import json +import logging + + +log = logging.getLogger(__name__) + + +class Capability(MutableMapping): + def __init__(self, obj=None): + if obj is None: + obj = {} + self.__dict = dict(obj) + for k, v in obj.items(): + self[k] = v + + def __eq__(self, other): + if isinstance(other, Capability): + return Capability.c14n(self) == Capability.c14n(other) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Capability): + return Capability.c14n(self) != Capability.c14n(other) + return NotImplemented + + def __getitem__(self, key): + return self.__dict[key] + + def __iter__(self): + return iter(self.__dict) + + def __len__(self): + return len(self.__dict) + + def __contains__(self, key): + return key in self.__dict + + def __setitem__(self, key, value): + # validate that the value is a list of ops and that the key is a string + if not isinstance(key, str): + raise ValueError('Capability keys must be strings') + + if isinstance(value, str): + value = [value] + + operations = set() + for val in iter(value): + if not isinstance(val, str): + raise ValueError('Operations must be strings') + operations.add(val) + + self.__dict[key] = operations + + def __delitem__(self, key): + del self.__dict[key] + + def setdefault(self, key, default): + if key not in self: + self[key] = default + return self[key] + + def add_resource(self, resource, operations=None): + if operations is None: + operations = [] + if isinstance(operations, str): + operations = [operations] + self[resource] = list(operations) + + def add_operation_to_resource(self, operation, resource): + self.setdefault(resource, []).append(operation) + + def __str__(self): + return Capability.c14n(self) + + def to_dict(self): + return {k: sorted(v) for k, v in self.items()} + + @staticmethod + def c14n(capability): + sorted_ops = capability.to_dict() + return json.dumps(sorted_ops, sort_keys=True) diff --git a/ably/sync/types/channeldetails.py b/ably/sync/types/channeldetails.py new file mode 100644 index 00000000..d959d487 --- /dev/null +++ b/ably/sync/types/channeldetails.py @@ -0,0 +1,116 @@ +from __future__ import annotations + + +class ChannelDetails: + + def __init__(self, channel_id, status): + self.__channel_id = channel_id + self.__status = status + + @property + def channel_id(self) -> str: + return self.__channel_id + + @property + def status(self) -> ChannelStatus: + return self.__status + + @staticmethod + def from_dict(obj): + kwargs = { + 'channel_id': obj.get("channelId"), + 'status': ChannelStatus.from_dict(obj.get("status")) + } + + return ChannelDetails(**kwargs) + + +class ChannelStatus: + + def __init__(self, is_active, occupancy): + self.__is_active = is_active + self.__occupancy = occupancy + + @property + def is_active(self) -> bool: + return self.__is_active + + @property + def occupancy(self) -> ChannelOccupancy: + return self.__occupancy + + @staticmethod + def from_dict(obj): + kwargs = { + 'is_active': obj.get("isActive"), + 'occupancy': ChannelOccupancy.from_dict(obj.get("occupancy")) + } + + return ChannelStatus(**kwargs) + + +class ChannelOccupancy: + + def __init__(self, metrics): + self.__metrics = metrics + + @property + def metrics(self) -> ChannelMetrics: + return self.__metrics + + @staticmethod + def from_dict(obj): + kwargs = { + 'metrics': ChannelMetrics.from_dict(obj.get("metrics")) + } + + return ChannelOccupancy(**kwargs) + + +class ChannelMetrics: + + def __init__(self, connections, presence_connections, presence_members, + presence_subscribers, publishers, subscribers): + self.__connections = connections + self.__presence_connections = presence_connections + self.__presence_members = presence_members + self.__presence_subscribers = presence_subscribers + self.__publishers = publishers + self.__subscribers = subscribers + + @property + def connections(self) -> int: + return self.__connections + + @property + def presence_connections(self) -> int: + return self.__presence_connections + + @property + def presence_members(self) -> int: + return self.__presence_members + + @property + def presence_subscribers(self) -> int: + return self.__presence_subscribers + + @property + def publishers(self) -> int: + return self.__publishers + + @property + def subscribers(self) -> int: + return self.__subscribers + + @staticmethod + def from_dict(obj): + kwargs = { + 'connections': obj.get("connections"), + 'presence_connections': obj.get("presenceConnections"), + 'presence_members': obj.get("presenceMembers"), + 'presence_subscribers': obj.get("presenceSubscribers"), + 'publishers': obj.get("publishers"), + 'subscribers': obj.get("subscribers") + } + + return ChannelMetrics(**kwargs) diff --git a/ably/sync/types/channelstate.py b/ably/sync/types/channelstate.py new file mode 100644 index 00000000..83352f7b --- /dev/null +++ b/ably/sync/types/channelstate.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional +from enum import Enum +from ably.sync.util.exceptions import AblyException + + +class ChannelState(str, Enum): + INITIALIZED = 'initialized' + ATTACHING = 'attaching' + ATTACHED = 'attached' + DETACHING = 'detaching' + DETACHED = 'detached' + SUSPENDED = 'suspended' + FAILED = 'failed' + + +@dataclass +class ChannelStateChange: + previous: ChannelState + current: ChannelState + resumed: bool + reason: Optional[AblyException] = None diff --git a/ably/sync/types/channelsubscription.py b/ably/sync/types/channelsubscription.py new file mode 100644 index 00000000..fec042ad --- /dev/null +++ b/ably/sync/types/channelsubscription.py @@ -0,0 +1,70 @@ +from ably.sync.util import case + + +class PushChannelSubscription: + + def __init__(self, channel, device_id=None, client_id=None, app_id=None): + if not device_id and not client_id: + raise ValueError('missing expected device or client id') + + if device_id and client_id: + raise ValueError('both device and client id given, only one expected') + + self.__channel = channel + self.__device_id = device_id + self.__client_id = client_id + self.__app_id = app_id + + @property + def channel(self): + return self.__channel + + @property + def device_id(self): + return self.__device_id + + @property + def client_id(self): + return self.__client_id + + @property + def app_id(self): + return self.__app_id + + def as_dict(self): + keys = ['channel', 'device_id', 'client_id', 'app_id'] + + obj = {} + for key in keys: + value = getattr(self, key) + if value is not None: + key = case.snake_to_camel(key) + obj[key] = value + + return obj + + @classmethod + def from_dict(cls, obj): + obj = {case.camel_to_snake(key): value for key, value in obj.items()} + return cls(**obj) + + @classmethod + def from_array(cls, array): + return [cls.from_dict(d) for d in array] + + @classmethod + def factory(cls, subscription): + if isinstance(subscription, cls): + return subscription + + return cls.from_dict(subscription) + + +def channel_subscriptions_response_processor(response): + native = response.to_native() + return PushChannelSubscription.from_array(native) + + +def channels_response_processor(response): + native = response.to_native() + return native diff --git a/ably/sync/types/connectiondetails.py b/ably/sync/types/connectiondetails.py new file mode 100644 index 00000000..a281daed --- /dev/null +++ b/ably/sync/types/connectiondetails.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + + +@dataclass() +class ConnectionDetails: + connection_state_ttl: int + max_idle_interval: int + connection_key: str + + def __init__(self, connection_state_ttl: int, max_idle_interval: int, + connection_key: str, client_id: str): + self.connection_state_ttl = connection_state_ttl + self.max_idle_interval = max_idle_interval + self.connection_key = connection_key + self.client_id = client_id + + @staticmethod + def from_dict(json_dict: dict): + return ConnectionDetails(json_dict.get('connectionStateTtl'), json_dict.get('maxIdleInterval'), + json_dict.get('connectionKey'), json_dict.get('clientId')) diff --git a/ably/sync/types/connectionerrors.py b/ably/sync/types/connectionerrors.py new file mode 100644 index 00000000..e63ddea9 --- /dev/null +++ b/ably/sync/types/connectionerrors.py @@ -0,0 +1,30 @@ +from ably.sync.types.connectionstate import ConnectionState +from ably.sync.util.exceptions import AblyException + +ConnectionErrors = { + ConnectionState.DISCONNECTED: AblyException( + 'Connection to server temporarily unavailable', + 400, + 80003, + ), + ConnectionState.SUSPENDED: AblyException( + 'Connection to server unavailable', + 400, + 80002, + ), + ConnectionState.FAILED: AblyException( + 'Connection failed or disconnected by server', + 400, + 80000, + ), + ConnectionState.CLOSING: AblyException( + 'Connection closing', + 400, + 80017, + ), + ConnectionState.CLOSED: AblyException( + 'Connection closed', + 400, + 80017, + ), +} diff --git a/ably/sync/types/connectionstate.py b/ably/sync/types/connectionstate.py new file mode 100644 index 00000000..24747466 --- /dev/null +++ b/ably/sync/types/connectionstate.py @@ -0,0 +1,36 @@ +from enum import Enum +from dataclasses import dataclass +from typing import Optional + +from ably.sync.util.exceptions import AblyException + + +class ConnectionState(str, Enum): + INITIALIZED = 'initialized' + CONNECTING = 'connecting' + CONNECTED = 'connected' + DISCONNECTED = 'disconnected' + CLOSING = 'closing' + CLOSED = 'closed' + FAILED = 'failed' + SUSPENDED = 'suspended' + + +class ConnectionEvent(str, Enum): + INITIALIZED = 'initialized' + CONNECTING = 'connecting' + CONNECTED = 'connected' + DISCONNECTED = 'disconnected' + CLOSING = 'closing' + CLOSED = 'closed' + FAILED = 'failed' + SUSPENDED = 'suspended' + UPDATE = 'update' + + +@dataclass +class ConnectionStateChange: + previous: ConnectionState + current: ConnectionState + event: ConnectionEvent + reason: Optional[AblyException] = None # RTN4f diff --git a/ably/sync/types/device.py b/ably/sync/types/device.py new file mode 100644 index 00000000..5cfefa5c --- /dev/null +++ b/ably/sync/types/device.py @@ -0,0 +1,116 @@ +from ably.sync.util import case + + +DevicePushTransportType = {'fcm', 'gcm', 'apns', 'web'} +DevicePlatform = {'android', 'ios', 'browser'} +DeviceFormFactor = {'phone', 'tablet', 'desktop', 'tv', 'watch', 'car', 'embedded', 'other'} + + +class DeviceDetails: + + def __init__(self, id, client_id=None, form_factor=None, metadata=None, + platform=None, push=None, update_token=None, app_id=None, + device_identity_token=None, modified=None, device_secret=None): + + if push: + recipient = push.get('recipient') + if recipient: + transport_type = recipient.get('transportType') + if transport_type is not None and transport_type not in DevicePushTransportType: + raise ValueError('unexpected transport type {}'.format(transport_type)) + + if platform is not None and platform not in DevicePlatform: + raise ValueError('unexpected platform {}'.format(platform)) + + if form_factor is not None and form_factor not in DeviceFormFactor: + raise ValueError('unexpected form factor {}'.format(form_factor)) + + self.__id = id + self.__client_id = client_id + self.__form_factor = form_factor + self.__metadata = metadata + self.__platform = platform + self.__push = push + self.__update_token = update_token + self.__app_id = app_id + self.__device_identity_token = device_identity_token + self.__modified = modified + self.__device_secret = device_secret + + @property + def id(self): + return self.__id + + @property + def client_id(self): + return self.__client_id + + @property + def form_factor(self): + return self.__form_factor + + @property + def metadata(self): + return self.__metadata + + @property + def platform(self): + return self.__platform + + @property + def push(self): + return self.__push + + @property + def update_token(self): + return self.__update_token + + @property + def app_id(self): + return self.__app_id + + @property + def device_identity_token(self): + return self.__device_identity_token + + @property + def modified(self): + return self.__modified + + @property + def device_secret(self): + return self.__device_secret + + def as_dict(self): + keys = ['id', 'client_id', 'form_factor', 'metadata', 'platform', + 'push', 'update_token', 'app_id', 'device_identity_token', 'modified', 'device_secret'] + + obj = {} + for key in keys: + value = getattr(self, key) + if value is not None: + key = case.snake_to_camel(key) + obj[key] = value + + return obj + + @classmethod + def from_dict(cls, obj): + obj = {case.camel_to_snake(key): value for key, value in obj.items()} + return cls(**obj) + + @classmethod + def from_array(cls, array): + return [cls.from_dict(d) for d in array] + + @classmethod + def factory(cls, device): + if isinstance(device, cls): + return device + + return cls.from_dict(device) + + +def device_details_response_processor(response): + native = response.to_native() + return DeviceDetails.from_array(native) diff --git a/ably/sync/types/flags.py b/ably/sync/types/flags.py new file mode 100644 index 00000000..1666434c --- /dev/null +++ b/ably/sync/types/flags.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class Flag(int, Enum): + # Channel attach state flags + HAS_PRESENCE = 1 << 0 + HAS_BACKLOG = 1 << 1 + RESUMED = 1 << 2 + TRANSIENT = 1 << 4 + ATTACH_RESUME = 1 << 5 + # Channel mode flags + PRESENCE = 1 << 16 + PUBLISH = 1 << 17 + SUBSCRIBE = 1 << 18 + PRESENCE_SUBSCRIBE = 1 << 19 + + +def has_flag(message_flags: int, flag: Flag): + return message_flags & flag > 0 diff --git a/ably/sync/types/message.py b/ably/sync/types/message.py new file mode 100644 index 00000000..43c0a03c --- /dev/null +++ b/ably/sync/types/message.py @@ -0,0 +1,233 @@ +import base64 +import json +import logging + +from ably.sync.types.typedbuffer import TypedBuffer +from ably.sync.types.mixins import EncodeDataMixin +from ably.sync.util.crypto import CipherData +from ably.sync.util.exceptions import AblyException + +log = logging.getLogger(__name__) + + +def to_text(value): + if value is None: + return value + elif isinstance(value, str): + return value + elif isinstance(value, bytes): + return value.decode() + else: + raise TypeError("expected string or bytes, not %s" % type(value)) + + +class Message(EncodeDataMixin): + + def __init__(self, + name=None, # TM2g + data=None, # TM2d + client_id=None, # TM2b + id=None, # TM2a + connection_id=None, # TM2c + connection_key=None, # TM2h + encoding='', # TM2e + timestamp=None, # TM2f + extras=None, # TM2i + ): + + super().__init__(encoding) + + self.__name = to_text(name) + self.__data = data + self.__client_id = to_text(client_id) + self.__id = to_text(id) + self.__connection_id = connection_id + self.__connection_key = connection_key + self.__timestamp = timestamp + self.__extras = extras + + def __eq__(self, other): + if isinstance(other, Message): + return (self.name == other.name + and self.data == other.data + and self.client_id == other.client_id + and self.timestamp == other.timestamp) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Message): + result = self.__eq__(other) + if result != NotImplemented: + return not result + return NotImplemented + + @property + def name(self): + return self.__name + + @property + def data(self): + return self.__data + + @property + def client_id(self): + return self.__client_id + + @property + def id(self): + return self.__id + + @id.setter + def id(self, value): + self.__id = value + + @property + def connection_id(self): + return self.__connection_id + + @property + def connection_key(self): + return self.__connection_key + + @property + def timestamp(self): + return self.__timestamp + + @property + def extras(self): + return self.__extras + + def encrypt(self, channel_cipher): + if isinstance(self.data, CipherData): + return + + elif isinstance(self.data, str): + self._encoding_array.append('utf-8') + + if isinstance(self.data, dict) or isinstance(self.data, list): + self._encoding_array.append('json') + self._encoding_array.append('utf-8') + + typed_data = TypedBuffer.from_obj(self.data) + if typed_data.buffer is None: + return True + encrypted_data = channel_cipher.encrypt(typed_data.buffer) + self.__data = CipherData(encrypted_data, typed_data.type, + cipher_type=channel_cipher.cipher_type) + + @staticmethod + def decrypt_data(channel_cipher, data): + if not isinstance(data, CipherData): + return + decrypted_data = channel_cipher.decrypt(data.buffer) + decrypted_typed_buffer = TypedBuffer(decrypted_data, data.type) + + return decrypted_typed_buffer.decode() + + def decrypt(self, channel_cipher): + decrypted_data = self.decrypt_data(channel_cipher, self.__data) + if decrypted_data is not None: + self.__data = decrypted_data + + def as_dict(self, binary=False): + data = self.data + data_type = None + encoding = self._encoding_array[:] + + if isinstance(data, (dict, list)): + encoding.append('json') + data = json.dumps(data) + data = str(data) + elif isinstance(data, str) and not binary: + pass + elif not binary and isinstance(data, (bytearray, bytes)): + data = base64.b64encode(data).decode('ascii') + encoding.append('base64') + elif isinstance(data, CipherData): + encoding.append(data.encoding_str) + data_type = data.type + if not binary: + data = base64.b64encode(data.buffer).decode('ascii') + encoding.append('base64') + else: + data = data.buffer + elif binary and isinstance(data, bytearray): + data = bytes(data) + + if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): + raise AblyException("Invalid data payload", 400, 40011) + + request_body = { + 'name': self.name, + 'data': data, + 'timestamp': self.timestamp or None, + 'type': data_type or None, + 'clientId': self.client_id or None, + 'id': self.id or None, + 'connectionId': self.connection_id or None, + 'connectionKey': self.connection_key or None, + 'extras': self.extras, + } + + if encoding: + request_body['encoding'] = '/'.join(encoding).strip('/') + + # None values aren't included + request_body = {k: v for k, v in request_body.items() if v is not None} + + return request_body + + @staticmethod + def from_encoded(obj, cipher=None): + id = obj.get('id') + name = obj.get('name') + data = obj.get('data') + client_id = obj.get('clientId') + connection_id = obj.get('connectionId') + timestamp = obj.get('timestamp') + encoding = obj.get('encoding', '') + extras = obj.get('extras', None) + + decoded_data = Message.decode(data, encoding, cipher) + + return Message( + id=id, + name=name, + connection_id=connection_id, + client_id=client_id, + timestamp=timestamp, + extras=extras, + **decoded_data + ) + + @staticmethod + def __update_empty_fields(proto_msg: dict, msg: dict, msg_index: int): + if msg.get("id") is None or msg.get("id") == '': + msg['id'] = f"{proto_msg.get('id')}:{msg_index}" + if msg.get("connectionId") is None or msg.get("connectionId") == '': + msg['connectionId'] = proto_msg.get('connectionId') + if msg.get("timestamp") is None or msg.get("timestamp") == 0: + msg['timestamp'] = proto_msg.get('timestamp') + + @staticmethod + def update_inner_message_fields(proto_msg: dict): + messages: list[dict] = proto_msg.get('messages') + presence_messages: list[dict] = proto_msg.get('presence') + if messages is not None: + msg_index = 0 + for msg in messages: + Message.__update_empty_fields(proto_msg, msg, msg_index) + msg_index = msg_index + 1 + + if presence_messages is not None: + msg_index = 0 + for presence_msg in presence_messages: + Message.__update_empty_fields(proto_msg, presence_msg, msg_index) + msg_index = msg_index + 1 + + +def make_message_response_handler(cipher): + def encrypted_message_response_handler(response): + messages = response.to_native() + return Message.from_encoded_array(messages, cipher=cipher) + return encrypted_message_response_handler diff --git a/ably/sync/types/mixins.py b/ably/sync/types/mixins.py new file mode 100644 index 00000000..d228611b --- /dev/null +++ b/ably/sync/types/mixins.py @@ -0,0 +1,75 @@ +import base64 +import json +import logging + +from ably.sync.util.crypto import CipherData + + +log = logging.getLogger(__name__) + + +class EncodeDataMixin: + + def __init__(self, encoding): + self.encoding = encoding + + @property + def encoding(self): + return '/'.join(self._encoding_array).strip('/') + + @encoding.setter + def encoding(self, encoding): + if not encoding: + self._encoding_array = [] + else: + self._encoding_array = encoding.strip('/').split('/') + + @staticmethod + def decode(data, encoding='', cipher=None): + encoding = encoding.strip('/') + encoding_list = encoding.split('/') + + while encoding_list: + encoding = encoding_list.pop() + if not encoding: + # With messagepack, binary data is sent as bytes, without need + # to specify the base64 encoding. Here we coerce to bytearray, + # since that's what is used with the Json transport; though it + # can be argued that it should be the other way, and use always + # bytes, never bytearray. + if type(data) is bytes: + data = bytearray(data) + continue + if encoding == 'json': + if isinstance(data, bytes): + data = data.decode() + if isinstance(data, list) or isinstance(data, dict): + continue + data = json.loads(data) + elif encoding == 'base64' and isinstance(data, bytes): + data = bytearray(base64.b64decode(data)) + elif encoding == 'base64': + data = bytearray(base64.b64decode(data.encode('utf-8'))) + elif encoding.startswith('%s+' % CipherData.ENCODING_ID): + if not cipher: + log.error('Message cannot be decrypted as the channel is ' + 'not set up for encryption & decryption') + encoding_list.append(encoding) + break + data = cipher.decrypt(data) + elif encoding == 'utf-8' and isinstance(data, (bytes, bytearray)): + data = data.decode('utf-8') + elif encoding == 'utf-8': + pass + else: + log.error('Message cannot be decoded. ' + "Unsupported encoding type: '%s'" % encoding) + encoding_list.append(encoding) + break + + encoding = '/'.join(encoding_list) + return {'encoding': encoding, 'data': data} + + @classmethod + def from_encoded_array(cls, objs, cipher=None): + return [cls.from_encoded(obj, cipher=cipher) for obj in objs] diff --git a/ably/sync/types/options.py b/ably/sync/types/options.py new file mode 100644 index 00000000..fb2dae2a --- /dev/null +++ b/ably/sync/types/options.py @@ -0,0 +1,330 @@ +import random +import logging + +from ably.sync.transport.defaults import Defaults +from ably.sync.types.authoptions import AuthOptions + +log = logging.getLogger(__name__) + + +class Options(AuthOptions): + def __init__(self, client_id=None, log_level=0, tls=True, rest_host=None, realtime_host=None, port=0, + tls_port=0, use_binary_protocol=True, queue_messages=False, recover=False, environment=None, + http_open_timeout=None, http_request_timeout=None, realtime_request_timeout=None, + http_max_retry_count=None, http_max_retry_duration=None, fallback_hosts=None, + fallback_retry_timeout=None, disconnected_retry_timeout=None, idempotent_rest_publishing=None, + loop=None, auto_connect=True, suspended_retry_timeout=None, connectivity_check_url=None, + channel_retry_timeout=Defaults.channel_retry_timeout, add_request_ids=False, **kwargs): + + super().__init__(**kwargs) + + # TODO check these defaults + if fallback_retry_timeout is None: + fallback_retry_timeout = Defaults.fallback_retry_timeout + + if realtime_request_timeout is None: + realtime_request_timeout = Defaults.realtime_request_timeout + + if disconnected_retry_timeout is None: + disconnected_retry_timeout = Defaults.disconnected_retry_timeout + + if connectivity_check_url is None: + connectivity_check_url = Defaults.connectivity_check_url + + connection_state_ttl = Defaults.connection_state_ttl + + if suspended_retry_timeout is None: + suspended_retry_timeout = Defaults.suspended_retry_timeout + + if environment is not None and rest_host is not None: + raise ValueError('specify rest_host or environment, not both') + + if environment is not None and realtime_host is not None: + raise ValueError('specify realtime_host or environment, not both') + + if idempotent_rest_publishing is None: + from ably.sync import api_version + idempotent_rest_publishing = api_version >= '1.2' + + if environment is None: + environment = Defaults.environment + + self.__client_id = client_id + self.__log_level = log_level + self.__tls = tls + self.__rest_host = rest_host + self.__realtime_host = realtime_host + self.__port = port + self.__tls_port = tls_port + self.__use_binary_protocol = use_binary_protocol + self.__queue_messages = queue_messages + self.__recover = recover + self.__environment = environment + self.__http_open_timeout = http_open_timeout + self.__http_request_timeout = http_request_timeout + self.__realtime_request_timeout = realtime_request_timeout + self.__http_max_retry_count = http_max_retry_count + self.__http_max_retry_duration = http_max_retry_duration + self.__fallback_hosts = fallback_hosts + self.__fallback_retry_timeout = fallback_retry_timeout + self.__disconnected_retry_timeout = disconnected_retry_timeout + self.__channel_retry_timeout = channel_retry_timeout + self.__idempotent_rest_publishing = idempotent_rest_publishing + self.__loop = loop + self.__auto_connect = auto_connect + self.__connection_state_ttl = connection_state_ttl + self.__suspended_retry_timeout = suspended_retry_timeout + self.__connectivity_check_url = connectivity_check_url + self.__fallback_realtime_host = None + self.__add_request_ids = add_request_ids + + self.__rest_hosts = self.__get_rest_hosts() + self.__realtime_hosts = self.__get_realtime_hosts() + + @property + def client_id(self): + return self.__client_id + + @client_id.setter + def client_id(self, value): + self.__client_id = value + + @property + def log_level(self): + return self.__log_level + + @log_level.setter + def log_level(self, value): + self.__log_level = value + + @property + def tls(self): + return self.__tls + + @tls.setter + def tls(self, value): + self.__tls = value + + @property + def rest_host(self): + return self.__rest_host + + @rest_host.setter + def rest_host(self, value): + self.__rest_host = value + + # RTC1d + @property + def realtime_host(self): + return self.__realtime_host + + @realtime_host.setter + def realtime_host(self, value): + self.__realtime_host = value + + @property + def port(self): + return self.__port + + @port.setter + def port(self, value): + self.__port = value + + @property + def tls_port(self): + return self.__tls_port + + @tls_port.setter + def tls_port(self, value): + self.__tls_port = value + + @property + def use_binary_protocol(self): + return self.__use_binary_protocol + + @use_binary_protocol.setter + def use_binary_protocol(self, value): + self.__use_binary_protocol = value + + @property + def queue_messages(self): + return self.__queue_messages + + @queue_messages.setter + def queue_messages(self, value): + self.__queue_messages = value + + @property + def recover(self): + return self.__recover + + @recover.setter + def recover(self, value): + self.__recover = value + + @property + def environment(self): + return self.__environment + + @property + def http_open_timeout(self): + return self.__http_open_timeout + + @http_open_timeout.setter + def http_open_timeout(self, value): + self.__http_open_timeout = value + + @property + def http_request_timeout(self): + return self.__http_request_timeout + + @property + def realtime_request_timeout(self): + return self.__realtime_request_timeout + + @http_request_timeout.setter + def http_request_timeout(self, value): + self.__http_request_timeout = value + + @property + def http_max_retry_count(self): + return self.__http_max_retry_count + + @http_max_retry_count.setter + def http_max_retry_count(self, value): + self.__http_max_retry_count = value + + @property + def http_max_retry_duration(self): + return self.__http_max_retry_duration + + @http_max_retry_duration.setter + def http_max_retry_duration(self, value): + self.__http_max_retry_duration = value + + @property + def fallback_hosts(self): + return self.__fallback_hosts + + @property + def fallback_retry_timeout(self): + return self.__fallback_retry_timeout + + @property + def disconnected_retry_timeout(self): + return self.__disconnected_retry_timeout + + @property + def channel_retry_timeout(self): + return self.__channel_retry_timeout + + @property + def idempotent_rest_publishing(self): + return self.__idempotent_rest_publishing + + @property + def loop(self): + return self.__loop + + # RTC1b + @property + def auto_connect(self): + return self.__auto_connect + + @property + def connection_state_ttl(self): + return self.__connection_state_ttl + + @connection_state_ttl.setter + def connection_state_ttl(self, value): + self.__connection_state_ttl = value + + @property + def suspended_retry_timeout(self): + return self.__suspended_retry_timeout + + @property + def connectivity_check_url(self): + return self.__connectivity_check_url + + @property + def fallback_realtime_host(self): + return self.__fallback_realtime_host + + @fallback_realtime_host.setter + def fallback_realtime_host(self, value): + self.__fallback_realtime_host = value + + @property + def add_request_ids(self): + return self.__add_request_ids + + def __get_rest_hosts(self): + """ + Return the list of hosts as they should be tried. First comes the main + host. Then the fallback hosts in random order. + The returned list will have a length of up to http_max_retry_count. + """ + # Defaults + host = self.rest_host + if host is None: + host = Defaults.rest_host + + environment = self.environment + + http_max_retry_count = self.http_max_retry_count + if http_max_retry_count is None: + http_max_retry_count = Defaults.http_max_retry_count + + # Prepend environment + if environment != 'production': + host = '%s-%s' % (environment, host) + + # Fallback hosts + fallback_hosts = self.fallback_hosts + if fallback_hosts is None: + if host == Defaults.rest_host: + fallback_hosts = Defaults.fallback_hosts + elif environment != 'production': + fallback_hosts = Defaults.get_environment_fallback_hosts(environment) + else: + fallback_hosts = [] + + # Shuffle + fallback_hosts = list(fallback_hosts) + random.shuffle(fallback_hosts) + self.__fallback_hosts = fallback_hosts + + # First main host + hosts = [host] + fallback_hosts + hosts = hosts[:http_max_retry_count] + return hosts + + def __get_realtime_hosts(self): + if self.realtime_host is not None: + host = self.realtime_host + return [host] + elif self.environment != "production": + host = f'{self.environment}-{Defaults.realtime_host}' + else: + host = Defaults.realtime_host + + return [host] + self.__fallback_hosts + + def get_rest_hosts(self): + return self.__rest_hosts + + def get_rest_host(self): + return self.__rest_hosts[0] + + def get_realtime_hosts(self): + return self.__realtime_hosts + + def get_realtime_host(self): + return self.__realtime_hosts[0] + + def get_fallback_rest_hosts(self): + return self.__rest_hosts[1:] + + def get_fallback_realtime_hosts(self): + return self.__realtime_hosts[1:] diff --git a/ably/sync/types/presence.py b/ably/sync/types/presence.py new file mode 100644 index 00000000..112c619c --- /dev/null +++ b/ably/sync/types/presence.py @@ -0,0 +1,174 @@ +from datetime import datetime, timedelta +from urllib import parse + +from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.types.mixins import EncodeDataMixin + + +def _ms_since_epoch(dt): + epoch = datetime.utcfromtimestamp(0) + delta = dt - epoch + return int(delta.total_seconds() * 1000) + + +def _dt_from_ms_epoch(ms): + epoch = datetime.utcfromtimestamp(0) + return epoch + timedelta(milliseconds=ms) + + +class PresenceAction: + ABSENT = 0 + PRESENT = 1 + ENTER = 2 + LEAVE = 3 + UPDATE = 4 + + +class PresenceMessage(EncodeDataMixin): + + def __init__(self, + id=None, # TP3a + action=None, # TP3b + client_id=None, # TP3c + connection_id=None, # TP3d + data=None, # TP3e + encoding=None, # TP3f + timestamp=None, # TP3g + member_key=None, # TP3h (for RT only) + extras=None, # TP3i (functionality not specified) + ): + + self.__id = id + self.__action = action + self.__client_id = client_id + self.__connection_id = connection_id + self.__data = data + self.__encoding = encoding + self.__timestamp = timestamp + self.__member_key = member_key + self.__extras = extras + + @property + def id(self): + return self.__id + + @property + def action(self): + return self.__action + + @property + def client_id(self): + return self.__client_id + + @property + def connection_id(self): + return self.__connection_id + + @property + def data(self): + return self.__data + + @property + def encoding(self): + return self.__encoding + + @property + def timestamp(self): + return self.__timestamp + + @property + def member_key(self): + if self.connection_id and self.client_id: + return "%s:%s" % (self.connection_id, self.client_id) + + @property + def extras(self): + return self.__extras + + @staticmethod + def from_encoded(obj, cipher=None): + id = obj.get('id') + action = obj.get('action', PresenceAction.ENTER) + client_id = obj.get('clientId') + connection_id = obj.get('connectionId') + data = obj.get('data') + encoding = obj.get('encoding', '') + timestamp = obj.get('timestamp') + # member_key = obj.get('memberKey', None) + extras = obj.get('extras', None) + + if timestamp is not None: + timestamp = _dt_from_ms_epoch(timestamp) + + decoded_data = PresenceMessage.decode(data, encoding, cipher) + + return PresenceMessage( + id=id, + action=action, + client_id=client_id, + connection_id=connection_id, + timestamp=timestamp, + extras=extras, + **decoded_data + ) + + +class Presence: + def __init__(self, channel): + self.__base_path = '/channels/%s/' % parse.quote_plus(channel.name) + self.__binary = channel.ably.options.use_binary_protocol + self.__http = channel.ably.http + self.__cipher = channel.cipher + + def _path_with_qs(self, rel_path, qs=None): + path = rel_path + if qs: + path += ('?' + parse.urlencode(qs)) + return path + + def get(self, limit=None): + qs = {} + if limit: + if limit > 1000: + raise ValueError("The maximum allowed limit is 1000") + qs['limit'] = limit + path = self._path_with_qs(self.__base_path + 'presence', qs) + + presence_handler = make_presence_response_handler(self.__cipher) + return PaginatedResult.paginated_query( + self.__http, url=path, response_processor=presence_handler) + + def history(self, limit=None, direction=None, start=None, end=None): + qs = {} + if limit: + if limit > 1000: + raise ValueError("The maximum allowed limit is 1000") + qs['limit'] = limit + if direction: + qs['direction'] = direction + if start: + if isinstance(start, int): + qs['start'] = start + else: + qs['start'] = _ms_since_epoch(start) + if end: + if isinstance(end, int): + qs['end'] = end + else: + qs['end'] = _ms_since_epoch(end) + + if 'start' in qs and 'end' in qs and qs['start'] > qs['end']: + raise ValueError("'end' parameter has to be greater than or equal to 'start'") + + path = self._path_with_qs(self.__base_path + 'presence/history', qs) + + presence_handler = make_presence_response_handler(self.__cipher) + return PaginatedResult.paginated_query( + self.__http, url=path, response_processor=presence_handler) + + +def make_presence_response_handler(cipher): + def encrypted_presence_response_handler(response): + messages = response.to_native() + return PresenceMessage.from_encoded_array(messages, cipher=cipher) + return encrypted_presence_response_handler diff --git a/ably/sync/types/stats.py b/ably/sync/types/stats.py new file mode 100644 index 00000000..ead5e548 --- /dev/null +++ b/ably/sync/types/stats.py @@ -0,0 +1,67 @@ +import logging +from datetime import datetime + +log = logging.getLogger(__name__) + + +class Stats: + + def __init__(self, entries=None, unit=None, interval_id=None, in_progress=None, app_id=None, schema=None): + self.interval_id = interval_id or '' + self.entries = entries + self.unit = unit + self.interval_time = interval_from_interval_id(self.interval_id) + self.in_progress = in_progress + self.app_id = app_id + self.schema = schema + + @classmethod + def from_dict(cls, stats_dict): + stats_dict = stats_dict or {} + + kwargs = { + "entries": stats_dict.get("entries"), + "unit": stats_dict.get("unit"), + "interval_id": stats_dict.get("intervalId"), + "in_progress": stats_dict.get("inProgress"), + "app_id": stats_dict.get("appId"), + "schema": stats_dict.get("schema"), + } + + return cls(**kwargs) + + @classmethod + def from_array(cls, stats_array): + return [cls.from_dict(d) for d in stats_array] + + @staticmethod + def to_interval_id(date_time, granularity): + return date_time.strftime(INTERVALS_FMT[granularity]) + + +def stats_response_processor(response): + stats_array = response.to_native() + return Stats.from_array(stats_array) + + +INTERVALS_FMT = { + 'minute': '%Y-%m-%d:%H:%M', + 'hour': '%Y-%m-%d:%H', + 'day': '%Y-%m-%d', + 'month': '%Y-%m', +} + + +def granularity_from_interval_id(interval_id): + for key, value in INTERVALS_FMT.items(): + try: + datetime.strptime(interval_id, value) + return key + except ValueError: + pass + raise ValueError("Unsupported intervalId") + + +def interval_from_interval_id(interval_id): + granularity = granularity_from_interval_id(interval_id) + return datetime.strptime(interval_id, INTERVALS_FMT[granularity]) diff --git a/ably/sync/types/tokendetails.py b/ably/sync/types/tokendetails.py new file mode 100644 index 00000000..4a898a5b --- /dev/null +++ b/ably/sync/types/tokendetails.py @@ -0,0 +1,97 @@ +import json +import time + +from ably.sync.types.capability import Capability + + +class TokenDetails: + + DEFAULTS = {'ttl': 60 * 60 * 1000} + # Buffer in milliseconds before a token is considered unusable + # For example, if buffer is 10000ms, the token can no longer be used for + # new requests 9000ms before it expires + TOKEN_EXPIRY_BUFFER = 15 * 1000 + + def __init__(self, token=None, expires=None, issued=0, + capability=None, client_id=None): + if expires is None: + self.__expires = time.time() * 1000 + TokenDetails.DEFAULTS['ttl'] + else: + self.__expires = expires + self.__token = token + self.__issued = issued + if capability and isinstance(capability, str): + try: + self.__capability = Capability(json.loads(capability)) + except json.JSONDecodeError: + self.__capability = Capability(json.loads(capability.replace("'", '"'))) + else: + self.__capability = Capability(capability or {}) + self.__client_id = client_id + + @property + def token(self): + return self.__token + + @property + def expires(self): + return self.__expires + + @property + def issued(self): + return self.__issued + + @property + def capability(self): + return self.__capability + + @property + def client_id(self): + return self.__client_id + + def to_dict(self): + return { + 'expires': self.expires, + 'token': self.token, + 'issued': self.issued, + 'capability': self.capability.to_dict(), + 'clientId': self.client_id, + } + + @staticmethod + def from_dict(obj): + kwargs = { + 'token': obj.get("token"), + 'capability': obj.get("capability"), + 'client_id': obj.get("clientId") + } + expires = obj.get("expires") + kwargs['expires'] = expires if expires is None else int(expires) + issued = obj.get("issued") + kwargs['issued'] = issued if issued is None else int(issued) + + return TokenDetails(**kwargs) + + @staticmethod + def from_json(data): + if isinstance(data, str): + data = json.loads(data) + + mapping = { + 'clientId': 'client_id', + } + for name in data: + py_name = mapping.get(name) + if py_name: + data[py_name] = data.pop(name) + + return TokenDetails(**data) + + def __eq__(self, other): + if isinstance(other, TokenDetails): + return (self.expires == other.expires + and self.token == other.token + and self.issued == other.issued + and self.capability == other.capability + and self.client_id == other.client_id) + return NotImplemented diff --git a/ably/sync/types/tokenrequest.py b/ably/sync/types/tokenrequest.py new file mode 100644 index 00000000..d10a5eb3 --- /dev/null +++ b/ably/sync/types/tokenrequest.py @@ -0,0 +1,107 @@ +import base64 +import hashlib +import hmac +import json + + +class TokenRequest: + + def __init__(self, key_name=None, client_id=None, nonce=None, mac=None, + capability=None, ttl=None, timestamp=None): + self.__key_name = key_name + self.__client_id = client_id + self.__nonce = nonce + self.__mac = mac + self.__capability = capability + self.__ttl = ttl + self.__timestamp = timestamp + + def sign_request(self, key_secret): + sign_text = "\n".join([str(x) for x in [ + self.key_name or "", + self.ttl or "", + self.capability or "", + self.client_id or "", + "%d" % (self.timestamp or 0), + self.nonce or "", + "", # to get the trailing new line + ]]) + try: + key_secret = key_secret.encode('utf8') + except AttributeError: + pass + try: + sign_text = sign_text.encode('utf8') + except AttributeError: + pass + mac = hmac.new(key_secret, sign_text, hashlib.sha256).digest() + self.mac = base64.b64encode(mac).decode('utf8') + + def to_dict(self): + return { + 'keyName': self.key_name, + 'clientId': self.client_id, + 'ttl': self.ttl, + 'nonce': self.nonce, + 'capability': self.capability, + 'timestamp': self.timestamp, + 'mac': self.mac + } + + @staticmethod + def from_json(data): + if isinstance(data, str): + data = json.loads(data) + + mapping = { + 'keyName': 'key_name', + 'clientId': 'client_id', + } + for name, py_name in mapping.items(): + if name in data: + data[py_name] = data.pop(name) + + return TokenRequest(**data) + + def __eq__(self, other): + if isinstance(other, TokenRequest): + return (self.key_name == other.key_name + and self.client_id == other.client_id + and self.nonce == other.nonce + and self.mac == other.mac + and self.capability == other.capability + and self.ttl == other.ttl + and self.timestamp == other.timestamp) + return NotImplemented + + @property + def key_name(self): + return self.__key_name + + @property + def client_id(self): + return self.__client_id + + @property + def nonce(self): + return self.__nonce + + @property + def mac(self): + return self.__mac + + @mac.setter + def mac(self, mac): + self.__mac = mac + + @property + def capability(self): + return self.__capability + + @property + def ttl(self): + return self.__ttl + + @property + def timestamp(self): + return self.__timestamp diff --git a/ably/sync/types/typedbuffer.py b/ably/sync/types/typedbuffer.py new file mode 100644 index 00000000..56adcd88 --- /dev/null +++ b/ably/sync/types/typedbuffer.py @@ -0,0 +1,104 @@ +# This functionality is depreceated and will be removed +# Message Pack is the replacement for all binary data messages + +import json +import struct + + +class DataType: + NONE = 0 + TRUE = 1 + FALSE = 2 + INT32 = 3 + INT64 = 4 + DOUBLE = 5 + STRING = 6 + BUFFER = 7 + JSONARRAY = 8 + JSONOBJECT = 9 + + +class Limits: + INT32_MAX = 2 ** 31 + INT32_MIN = -(2 ** 31 + 1) + INT64_MAX = 2 ** 63 + INT64_MIN = - (2 ** 63 + 1) + + +_decoders = {DataType.TRUE: lambda b: True, + DataType.FALSE: lambda b: False, + DataType.INT32: lambda b: struct.unpack('>i', b)[0], + DataType.INT64: lambda b: struct.unpack('>q', b)[0], + DataType.DOUBLE: lambda b: struct.unpack('>d', b)[0], + DataType.STRING: lambda b: b.decode('utf-8'), + DataType.BUFFER: lambda b: b, + DataType.JSONARRAY: lambda b: json.loads(b.decode('utf-8')), + DataType.JSONOBJECT: lambda b: json.loads(b.decode('utf-8'))} + + +class TypedBuffer: + def __init__(self, buffer, type): + self.__buffer = buffer + self.__type = type + + def __eq__(self, other): + if isinstance(other, TypedBuffer): + return self.buffer == other.buffer and self.type == other.type + return NotImplemented + + def __ne__(self, other): + if isinstance(other, TypedBuffer): + result = self.__eq__(other) + if result != NotImplemented: + return not result + return NotImplemented + + @staticmethod + def from_obj(obj): + if isinstance(obj, TypedBuffer): + return obj + elif isinstance(obj, (bytes, bytearray)): + data_type = DataType.BUFFER + buffer = obj + elif isinstance(obj, str): + data_type = DataType.STRING + buffer = obj.encode('utf-8') + elif isinstance(obj, bool): + data_type = DataType.TRUE if obj else DataType.FALSE + buffer = None + elif isinstance(obj, int): + if Limits.INT32_MIN <= obj <= Limits.INT32_MAX: + data_type = DataType.INT32 + buffer = struct.pack('>i', obj) + elif Limits.INT64_MIN <= obj <= Limits.INT64_MAX: + data_type = DataType.INT64 + buffer = struct.pack('>q', obj) + else: + raise ValueError('Number too large %d' % obj) + elif isinstance(obj, float): + data_type = DataType.DOUBLE + buffer = struct.pack('>d', obj) + elif isinstance(obj, list): + data_type = DataType.JSONARRAY + buffer = json.dumps(obj, separators=(',', ':')).encode('utf-8') + elif isinstance(obj, dict): + data_type = DataType.JSONOBJECT + buffer = json.dumps(obj, separators=(',', ':')).encode('utf-8') + else: + raise TypeError('Unexpected object type %s' % type(obj)) + + return TypedBuffer(buffer, data_type) + + @property + def buffer(self): + return self.__buffer + + @property + def type(self): + return self.__type + + def decode(self): + decoder = _decoders.get(self.type) + if decoder is not None: + return decoder(self.buffer) + raise ValueError('Unsupported data type %s' % self.type) diff --git a/ably/sync/util/__init__.py b/ably/sync/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ably/sync/util/case.py b/ably/sync/util/case.py new file mode 100644 index 00000000..3b18c49e --- /dev/null +++ b/ably/sync/util/case.py @@ -0,0 +1,18 @@ +import re + + +first_cap_re = re.compile('(.)([A-Z][a-z]+)') +all_cap_re = re.compile('([a-z0-9])([A-Z])') + + +def camel_to_snake(name): + s1 = first_cap_re.sub(r'\1_\2', name) + return all_cap_re.sub(r'\1_\2', s1).lower() + + +def snake_to_camel(name): + name = name.split('_') + for i in range(1, len(name)): + name[i] = name[i].title() + + return ''.join(name) diff --git a/ably/sync/util/crypto.py b/ably/sync/util/crypto.py new file mode 100644 index 00000000..bf1a9a35 --- /dev/null +++ b/ably/sync/util/crypto.py @@ -0,0 +1,179 @@ +import base64 +import logging + +try: + from Crypto.Cipher import AES + from Crypto import Random +except ImportError: + from .nocrypto import AES, Random + +from ably.sync.types.typedbuffer import TypedBuffer +from ably.sync.util.exceptions import AblyException + +log = logging.getLogger(__name__) + + +class CipherParams: + def __init__(self, algorithm='AES', mode='CBC', secret_key=None, iv=None): + self.__algorithm = algorithm.upper() + self.__secret_key = secret_key + self.__key_length = len(secret_key) * 8 if secret_key is not None else 128 + self.__mode = mode.upper() + self.__iv = iv + + @property + def algorithm(self): + return self.__algorithm + + @property + def secret_key(self): + return self.__secret_key + + @property + def iv(self): + return self.__iv + + @property + def key_length(self): + return self.__key_length + + @property + def mode(self): + return self.__mode + + +class CbcChannelCipher: + def __init__(self, cipher_params): + self.__secret_key = (cipher_params.secret_key or + self.__random(cipher_params.key_length / 8)) + if isinstance(self.__secret_key, str): + self.__secret_key = self.__secret_key.encode() + self.__iv = cipher_params.iv or self.__random(16) + self.__block_size = len(self.__iv) + if cipher_params.algorithm != 'AES': + raise NotImplementedError('Only AES algorithm is supported') + self.__algorithm = cipher_params.algorithm + if cipher_params.mode != 'CBC': + raise NotImplementedError('Only CBC mode is supported') + self.__mode = cipher_params.mode + self.__key_length = cipher_params.key_length + self.__encryptor = AES.new(self.__secret_key, AES.MODE_CBC, self.__iv) + + def __pad(self, data): + padding_size = self.__block_size - (len(data) % self.__block_size) + + padding_char = bytes((padding_size,)) + padded = data + padding_char * padding_size + + return padded + + def __unpad(self, data): + padding_size = data[-1] + + if padding_size > len(data): + # Too short + raise AblyException('invalid-padding', 0, 0) + + if padding_size == 0: + # Missing padding + raise AblyException('invalid-padding', 0, 0) + + for i in range(padding_size): + # Invalid padding bytes + if padding_size != data[-i - 1]: + raise AblyException('invalid-padding', 0, 0) + + return data[:-padding_size] + + def __random(self, length): + rndfile = Random.new() + return rndfile.read(length) + + def encrypt(self, plaintext): + if isinstance(plaintext, bytearray): + plaintext = bytes(plaintext) + padded_plaintext = self.__pad(plaintext) + encrypted = self.__iv + self.__encryptor.encrypt(padded_plaintext) + self.__iv = encrypted[-self.__block_size:] + return encrypted + + def decrypt(self, ciphertext): + if isinstance(ciphertext, bytearray): + ciphertext = bytes(ciphertext) + iv = ciphertext[:self.__block_size] + ciphertext = ciphertext[self.__block_size:] + decryptor = AES.new(self.__secret_key, AES.MODE_CBC, iv) + decrypted = decryptor.decrypt(ciphertext) + return bytearray(self.__unpad(decrypted)) + + @property + def secret_key(self): + return self.__secret_key + + @property + def iv(self): + return self.__iv + + @property + def cipher_type(self): + return ("%s-%s-%s" % (self.__algorithm, self.__key_length, + self.__mode)).lower() + + +class CipherData(TypedBuffer): + ENCODING_ID = 'cipher' + + def __init__(self, buffer, type, cipher_type=None, **kwargs): + self.__cipher_type = cipher_type + super().__init__(buffer, type, **kwargs) + + @property + def encoding_str(self): + return self.ENCODING_ID + '+' + self.__cipher_type + + +DEFAULT_KEYLENGTH = 256 +DEFAULT_BLOCKLENGTH = 16 + + +def generate_random_key(length=DEFAULT_KEYLENGTH): + rndfile = Random.new() + return rndfile.read(length // 8) + + +def get_default_params(params=None): + if type(params) in [str, bytes]: + raise ValueError("Calling get_default_params with a key directly is deprecated, it expects a params dict") + + key = params.get('key') + algorithm = params.get('algorithm') or 'AES' + iv = params.get('iv') or generate_random_key(DEFAULT_BLOCKLENGTH * 8) + mode = params.get('mode') or 'CBC' + + if not key: + raise ValueError("Crypto.get_default_params: a key is required") + + if type(key) == str: + key = base64.b64decode(key) + + cipher_params = CipherParams(algorithm=algorithm, secret_key=key, iv=iv, mode=mode) + validate_cipher_params(cipher_params) + return cipher_params + + +def get_cipher(params): + if isinstance(params, CipherParams): + cipher_params = params + else: + cipher_params = get_default_params(params) + return CbcChannelCipher(cipher_params) + + +def validate_cipher_params(cipher_params): + if cipher_params.algorithm == 'AES' and cipher_params.mode == 'CBC': + key_length = cipher_params.key_length + if key_length == 128 or key_length == 256: + return + raise ValueError( + 'Unsupported key length %s for aes-cbc encryption. Encryption key must be 128 or 256 bits' + ' (16 or 32 ASCII characters)' % key_length) diff --git a/ably/sync/util/eventemitter.py b/ably/sync/util/eventemitter.py new file mode 100644 index 00000000..47c139db --- /dev/null +++ b/ably/sync/util/eventemitter.py @@ -0,0 +1,185 @@ +import asyncio +import logging +from pyee.asyncio import AsyncIOEventEmitter + +from ably.sync.util.helper import is_callable_or_coroutine + +# pyee's event emitter doesn't support attaching a listener to all events +# so to patch it, we create a wrapper which uses two event emitters, one +# is used to listen to all events and this arbitrary string is the event name +# used to emit all events on that listener +_all_event = 'all' + +log = logging.getLogger(__name__) + + +def _is_named_event_args(*args): + return len(args) == 2 and is_callable_or_coroutine(args[1]) + + +def _is_all_event_args(*args): + return len(args) == 1 and is_callable_or_coroutine(args[0]) + + +class EventEmitter: + """ + A generic interface for event registration and delivery used in a number of the types in the Realtime client + library. For example, the Connection object emits events for connection state using the EventEmitter pattern. + + Methods + ------- + on(*args) + Attach to channel + once(*args) + Detach from channel + off() + Subscribe to messages on a channel + """ + + def __init__(self): + self.__named_event_emitter = AsyncIOEventEmitter() + self.__all_event_emitter = AsyncIOEventEmitter() + self.__wrapped_listeners = {} + + def on(self, *args): + """ + Registers the provided listener for the specified event, if provided, and otherwise for all events. + If on() is called more than once with the same listener and event, the listener is added multiple times to + its listener registry. Therefore, as an example, assuming the same listener is registered twice using + on(), and an event is emitted once, the listener would be invoked twice. + + Parameters + ---------- + name : str + The named event to listen for. + listener : callable + The event listener. + """ + if _is_all_event_args(*args): + event = _all_event + listener = args[0] + emitter = self.__all_event_emitter + # self.__all_event_emitter.add_listener(_all_event, args[0]) + elif _is_named_event_args(*args): + event = args[0] + listener = args[1] + emitter = self.__named_event_emitter + # self.__named_event_emitter.add_listener(args[0], args[1]) + else: + raise ValueError("EventEmitter.on(): invalid args") + + if asyncio.iscoroutinefunction(listener): + def wrapped_listener(*args, **kwargs): + try: + listener(*args, **kwargs) + except Exception as err: + log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') + else: + def wrapped_listener(*args, **kwargs): + try: + listener(*args, **kwargs) + except Exception as err: + log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') + + self.__wrapped_listeners[listener] = wrapped_listener + + emitter.add_listener(event, wrapped_listener) + + def once(self, *args): + """ + Registers the provided listener for the first event that is emitted. If once() is called more than once + with the same listener, the listener is added multiple times to its listener registry. Therefore, as an + example, assuming the same listener is registered twice using once(), and an event is emitted once, the + listener would be invoked twice. However, all subsequent events emitted would not invoke the listener as + once() ensures that each registration is only invoked once. + + Parameters + ---------- + name : str + The named event to listen for. + listener : callable + The event listener. + """ + if _is_all_event_args(*args): + event = _all_event + listener = args[0] + emitter = self.__all_event_emitter + # self.__all_event_emitter.add_listener(_all_event, args[0]) + elif _is_named_event_args(*args): + event = args[0] + listener = args[1] + emitter = self.__named_event_emitter + # self.__named_event_emitter.add_listener(args[0], args[1]) + else: + raise ValueError("EventEmitter.on(): invalid args") + + if asyncio.iscoroutinefunction(listener): + def wrapped_listener(*args, **kwargs): + try: + listener(*args, **kwargs) + except Exception as err: + log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') + else: + def wrapped_listener(*args, **kwargs): + try: + listener(*args, **kwargs) + except Exception as err: + log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') + + self.__wrapped_listeners[listener] = wrapped_listener + + emitter.once(event, wrapped_listener) + + def off(self, *args): + """ + Removes all registrations that match both the specified listener and, if provided, the specified event. + If called with no arguments, deregisters all registrations, for all events and listeners. + + Parameters + ---------- + name : str + The named event to listen for. + listener : callable + The event listener. + """ + if len(args) == 0: + self.__all_event_emitter.remove_all_listeners() + self.__named_event_emitter.remove_all_listeners() + return + elif _is_all_event_args(*args): + event = _all_event + listener = args[0] + emitter = self.__all_event_emitter + elif _is_named_event_args(*args): + event = args[0] + listener = args[1] + emitter = self.__named_event_emitter + else: + raise ValueError("EventEmitter.once(): invalid args") + + wrapped_listener = self.__wrapped_listeners.get(listener) + + if wrapped_listener is None: + return + + emitter.remove_listener(event, wrapped_listener) + self.__wrapped_listeners[listener] = None + + def once_async(self, state=None): + future = asyncio.Future() + + def on_state_change(*args): + future.set_result(*args) + + if state is not None: + self.once(state, on_state_change) + else: + self.once(on_state_change) + + state_change = future + + return state_change + + def _emit(self, *args): + self.__named_event_emitter.emit(*args) + self.__all_event_emitter.emit(_all_event, *args[1:]) diff --git a/ably/sync/util/exceptions.py b/ably/sync/util/exceptions.py new file mode 100644 index 00000000..090cf3d8 --- /dev/null +++ b/ably/sync/util/exceptions.py @@ -0,0 +1,92 @@ +import functools +import logging + + +log = logging.getLogger(__name__) + + +class AblyException(Exception): + def __new__(cls, message, status_code, code, cause=None): + if cls == AblyException and status_code == 401: + return AblyAuthException(message, status_code, code, cause) + return super().__new__(cls, message, status_code, code, cause) + + def __init__(self, message, status_code, code, cause=None): + super().__init__() + self.message = message + self.code = code + self.status_code = status_code + self.cause = cause + + def __str__(self): + str = '%s %s %s' % (self.code, self.status_code, self.message) + if self.cause is not None: + str += ' (cause: %s)' % self.cause + return str + + @property + def is_server_error(self): + return 500 <= self.status_code <= 599 + + @staticmethod + def raise_for_response(response): + if 200 <= response.status_code < 300: + # Valid response + return + + try: + json_response = response.json() + except Exception: + log.debug("Response not json: %d %s", + response.status_code, + response.text) + raise AblyException(message=response.text, + status_code=response.status_code, + code=response.status_code * 100) + + if json_response and 'error' in json_response: + error = json_response['error'] + try: + raise AblyException( + message=error['message'], + status_code=error['statusCode'], + code=int(error['code']), + ) + except KeyError: + msg = "Unexpected exception decoding server response: %s" + msg = msg % response.text + raise AblyException(message=msg, status_code=500, code=50000) + + raise AblyException(message="", + status_code=response.status_code, + code=response.status_code * 100) + + @staticmethod + def from_exception(e): + if isinstance(e, AblyException): + return e + return AblyException("Unexpected exception: %s" % e, 500, 50000) + + @staticmethod + def from_dict(value: dict): + return AblyException(value.get('message'), value.get('statusCode'), value.get('code')) + + +def catch_all(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + log.exception(e) + raise AblyException.from_exception(e) + + return wrapper + + +class AblyAuthException(AblyException): + pass + + +class IncompatibleClientIdException(AblyException): + pass diff --git a/ably/sync/util/helper.py b/ably/sync/util/helper.py new file mode 100644 index 00000000..a844204e --- /dev/null +++ b/ably/sync/util/helper.py @@ -0,0 +1,42 @@ +import inspect +import random +import string +import asyncio +import time +from typing import Callable + + +def get_random_id(): + # get random string of letters and digits + source = string.ascii_letters + string.digits + random_id = ''.join((random.choice(source) for i in range(8))) + return random_id + + +def is_callable_or_coroutine(value): + return asyncio.iscoroutinefunction(value) or inspect.isfunction(value) or inspect.ismethod(value) + + +def unix_time_ms(): + return round(time.time_ns() / 1_000_000) + + +def is_token_error(exception): + return 40140 <= exception.code < 40150 + + +class Timer: + def __init__(self, timeout: float, callback: Callable): + self._timeout = timeout + self._callback = callback + self._task = asyncio.create_task(self._job()) + + def _job(self): + asyncio.sleep(self._timeout / 1000) + if asyncio.iscoroutinefunction(self._callback): + self._callback() + else: + self._callback() + + def cancel(self): + self._task.cancel() diff --git a/ably/sync/util/nocrypto.py b/ably/sync/util/nocrypto.py new file mode 100644 index 00000000..a66669b3 --- /dev/null +++ b/ably/sync/util/nocrypto.py @@ -0,0 +1,9 @@ + +class InstallPycrypto: + def __getattr__(self, name): + raise ImportError( + "This requires to install ably with crypto support: pip install 'ably[crypto]'" + ) + + +AES = Random = InstallPycrypto() From c9f238642655f5b3addf16bd2bfb819c7b3d1ca5 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 15:03:48 +0530 Subject: [PATCH 14/52] Created sync directory under test to maintain ably rest test code --- test/ably/rest/restcrypto_test.py | 528 +++++++------- test/ably/sync/rest/encoders_test.py | 456 ++++++++++++ test/ably/sync/rest/restauth_test.py | 652 ++++++++++++++++++ test/ably/sync/rest/restcapability_test.py | 243 +++++++ .../ably/sync/rest/restchannelhistory_test.py | 332 +++++++++ .../ably/sync/rest/restchannelpublish_test.py | 568 +++++++++++++++ test/ably/sync/rest/restchannels_test.py | 91 +++ test/ably/sync/rest/restchannelstatus_test.py | 47 ++ test/ably/sync/rest/restcrypto_test.py | 264 +++++++ test/ably/sync/rest/resthttp_test.py | 229 ++++++ test/ably/sync/rest/restinit_test.py | 227 ++++++ .../sync/rest/restpaginatedresult_test.py | 91 +++ test/ably/sync/rest/restpresence_test.py | 213 ++++++ test/ably/sync/rest/restpush_test.py | 398 +++++++++++ test/ably/sync/rest/restrequest_test.py | 132 ++++ test/ably/sync/rest/reststats_test.py | 310 +++++++++ test/ably/sync/rest/resttime_test.py | 43 ++ test/ably/sync/rest/resttoken_test.py | 342 +++++++++ test/ably/sync/testapp.py | 115 +++ test/ably/sync/utils.py | 168 +++++ 20 files changed, 5185 insertions(+), 264 deletions(-) create mode 100644 test/ably/sync/rest/encoders_test.py create mode 100644 test/ably/sync/rest/restauth_test.py create mode 100644 test/ably/sync/rest/restcapability_test.py create mode 100644 test/ably/sync/rest/restchannelhistory_test.py create mode 100644 test/ably/sync/rest/restchannelpublish_test.py create mode 100644 test/ably/sync/rest/restchannels_test.py create mode 100644 test/ably/sync/rest/restchannelstatus_test.py create mode 100644 test/ably/sync/rest/restcrypto_test.py create mode 100644 test/ably/sync/rest/resthttp_test.py create mode 100644 test/ably/sync/rest/restinit_test.py create mode 100644 test/ably/sync/rest/restpaginatedresult_test.py create mode 100644 test/ably/sync/rest/restpresence_test.py create mode 100644 test/ably/sync/rest/restpush_test.py create mode 100644 test/ably/sync/rest/restrequest_test.py create mode 100644 test/ably/sync/rest/reststats_test.py create mode 100644 test/ably/sync/rest/resttime_test.py create mode 100644 test/ably/sync/rest/resttoken_test.py create mode 100644 test/ably/sync/testapp.py create mode 100644 test/ably/sync/utils.py diff --git a/test/ably/rest/restcrypto_test.py b/test/ably/rest/restcrypto_test.py index 18bf69ac..3dd89bc2 100644 --- a/test/ably/rest/restcrypto_test.py +++ b/test/ably/rest/restcrypto_test.py @@ -1,264 +1,264 @@ -import json -import os -import logging -import base64 - -import pytest - -from ably import AblyException -from ably.types.message import Message -from ably.util.crypto import CipherParams, get_cipher, generate_random_key, get_default_params - -from Crypto import Random - -from test.ably.testapp import TestApp -from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase - -log = logging.getLogger(__name__) - - -class TestRestCrypto(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - async def asyncSetUp(self): - self.test_vars = await TestApp.get_test_vars() - self.ably = await TestApp.get_ably_rest() - self.ably2 = await TestApp.get_ably_rest() - - async def asyncTearDown(self): - await self.ably.close() - await self.ably2.close() - - def per_protocol_setup(self, use_binary_protocol): - # This will be called every test that vary by protocol for each protocol - self.ably.options.use_binary_protocol = use_binary_protocol - self.ably2.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - @dont_vary_protocol - def test_cbc_channel_cipher(self): - key = ( - b'\x93\xe3\x5c\xc9\x77\x53\xfd\x1a' - b'\x79\xb4\xd8\x84\xe7\xdc\xfd\xdf') - - iv = ( - b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' - b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0') - - log.debug("KEY_LEN: %d" % len(key)) - log.debug("IV_LEN: %d" % len(iv)) - cipher = get_cipher({'key': key, 'iv': iv}) - - plaintext = b"The quick brown fox" - expected_ciphertext = ( - b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' - b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0' - b'\x83\x5c\xcf\xce\x0c\xfd\xbe\x37' - b'\xb7\x92\x12\x04\x1d\x45\x68\xa4' - b'\xdf\x7f\x6e\x38\x17\x4a\xff\x50' - b'\x73\x23\xbb\xca\x16\xb0\xe2\x84') - - actual_ciphertext = cipher.encrypt(plaintext) - - assert expected_ciphertext == actual_ciphertext - - async def test_crypto_publish(self): - channel_name = self.get_channel_name('persisted:crypto_publish_text') - publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) - - await publish0.publish("publish3", "This is a string message payload") - await publish0.publish("publish4", b"This is a byte[] message payload") - await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) - await publish0.publish("publish6", ["This is a JSONArray message payload"]) - - history = await publish0.history() - messages = history.items - assert messages is not None, "Expected non-None messages" - assert 4 == len(messages), "Expected 4 messages" - - message_contents = dict((m.name, m.data) for m in messages) - log.debug("message_contents: %s" % str(message_contents)) - - assert "This is a string message payload" == message_contents["publish3"],\ - "Expect publish3 to be expected String)" - - assert b"This is a byte[] message payload" == message_contents["publish4"],\ - "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) - - assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ - "Expect publish5 to be expected JSONObject" - - assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ - "Expect publish6 to be expected JSONObject" - - async def test_crypto_publish_256(self): - rndfile = Random.new() - key = rndfile.read(32) - channel_name = 'persisted:crypto_publish_text_256' - channel_name += '_bin' if self.use_binary_protocol else '_text' - - publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) - - await publish0.publish("publish3", "This is a string message payload") - await publish0.publish("publish4", b"This is a byte[] message payload") - await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) - await publish0.publish("publish6", ["This is a JSONArray message payload"]) - - history = await publish0.history() - messages = history.items - assert messages is not None, "Expected non-None messages" - assert 4 == len(messages), "Expected 4 messages" - - message_contents = dict((m.name, m.data) for m in messages) - log.debug("message_contents: %s" % str(message_contents)) - - assert "This is a string message payload" == message_contents["publish3"],\ - "Expect publish3 to be expected String)" - - assert b"This is a byte[] message payload" == message_contents["publish4"],\ - "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) - - assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ - "Expect publish5 to be expected JSONObject" - - assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ - "Expect publish6 to be expected JSONObject" - - async def test_crypto_publish_key_mismatch(self): - channel_name = self.get_channel_name('persisted:crypto_publish_key_mismatch') - - publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) - - await publish0.publish("publish3", "This is a string message payload") - await publish0.publish("publish4", b"This is a byte[] message payload") - await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) - await publish0.publish("publish6", ["This is a JSONArray message payload"]) - - rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) - - with pytest.raises(AblyException) as excinfo: - await rx_channel.history() - - message = excinfo.value.message - assert 'invalid-padding' == message or "codec can't decode" in message - - async def test_crypto_send_unencrypted(self): - channel_name = self.get_channel_name('persisted:crypto_send_unencrypted') - publish0 = self.ably.channels[channel_name] - - await publish0.publish("publish3", "This is a string message payload") - await publish0.publish("publish4", b"This is a byte[] message payload") - await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) - await publish0.publish("publish6", ["This is a JSONArray message payload"]) - - rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) - - history = await rx_channel.history() - messages = history.items - assert messages is not None, "Expected non-None messages" - assert 4 == len(messages), "Expected 4 messages" - - message_contents = dict((m.name, m.data) for m in messages) - log.debug("message_contents: %s" % str(message_contents)) - - assert "This is a string message payload" == message_contents["publish3"],\ - "Expect publish3 to be expected String" - - assert b"This is a byte[] message payload" == message_contents["publish4"],\ - "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) - - assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ - "Expect publish5 to be expected JSONObject" - - assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ - "Expect publish6 to be expected JSONObject" - - async def test_crypto_encrypted_unhandled(self): - channel_name = self.get_channel_name('persisted:crypto_send_encrypted_unhandled') - key = b'0123456789abcdef' - data = 'foobar' - publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) - - await publish0.publish("publish0", data) - - rx_channel = self.ably2.channels[channel_name] - history = await rx_channel.history() - message = history.items[0] - cipher = get_cipher(get_default_params({'key': key})) - assert cipher.decrypt(message.data).decode() == data - assert message.encoding == 'utf-8/cipher+aes-128-cbc' - - @dont_vary_protocol - def test_cipher_params(self): - params = CipherParams(secret_key='0123456789abcdef') - assert params.algorithm == 'AES' - assert params.mode == 'CBC' - assert params.key_length == 128 - - params = CipherParams(secret_key='0123456789abcdef' * 2) - assert params.algorithm == 'AES' - assert params.mode == 'CBC' - assert params.key_length == 256 - - -class AbstractTestCryptoWithFixture: - - @classmethod - def setUpClass(cls): - resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file - with open(resources_path, 'r') as f: - cls.fixture = json.loads(f.read()) - cls.params = { - 'secret_key': base64.b64decode(cls.fixture['key'].encode('ascii')), - 'mode': cls.fixture['mode'], - 'algorithm': cls.fixture['algorithm'], - 'iv': base64.b64decode(cls.fixture['iv'].encode('ascii')), - } - cls.cipher_params = CipherParams(**cls.params) - cls.cipher = get_cipher(cls.cipher_params) - cls.items = cls.fixture['items'] - - def get_encoded(self, encoded_item): - if encoded_item.get('encoding') == 'base64': - return base64.b64decode(encoded_item['data'].encode('ascii')) - elif encoded_item.get('encoding') == 'json': - return json.loads(encoded_item['data']) - return encoded_item['data'] - - # TM3 - def test_decode(self): - for item in self.items: - assert item['encoded']['name'] == item['encrypted']['name'] - message = Message.from_encoded(item['encrypted'], self.cipher) - assert message.encoding == '' - expected_data = self.get_encoded(item['encoded']) - assert expected_data == message.data - - # TM3 - def test_decode_array(self): - items_encrypted = [item['encrypted'] for item in self.items] - messages = Message.from_encoded_array(items_encrypted, self.cipher) - for i, message in enumerate(messages): - assert message.encoding == '' - expected_data = self.get_encoded(self.items[i]['encoded']) - assert expected_data == message.data - - def test_encode(self): - for item in self.items: - # need to reset iv - self.cipher_params = CipherParams(**self.params) - self.cipher = get_cipher(self.cipher_params) - data = self.get_encoded(item['encoded']) - expected = item['encrypted'] - message = Message(item['encoded']['name'], data) - message.encrypt(self.cipher) - as_dict = message.as_dict() - assert as_dict['data'] == expected['data'] - assert as_dict['encoding'] == expected['encoding'] - - -class TestCryptoWithFixture128(AbstractTestCryptoWithFixture, BaseTestCase): - fixture_file = 'crypto-data-128.json' - - -class TestCryptoWithFixture256(AbstractTestCryptoWithFixture, BaseTestCase): - fixture_file = 'crypto-data-256.json' +# import json +# import os +# import logging +# import base64 +# +# import pytest +# +# from ably import AblyException +# from ably.types.message import Message +# from ably.util.crypto import CipherParams, get_cipher, generate_random_key, get_default_params +# +# from Crypto import Random +# +# from test.ably.testapp import TestApp +# from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase +# +# log = logging.getLogger(__name__) +# +# +# class TestRestCrypto(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): +# +# async def asyncSetUp(self): +# self.test_vars = await TestApp.get_test_vars() +# self.ably = await TestApp.get_ably_rest() +# self.ably2 = await TestApp.get_ably_rest() +# +# async def asyncTearDown(self): +# await self.ably.close() +# await self.ably2.close() +# +# def per_protocol_setup(self, use_binary_protocol): +# # This will be called every test that vary by protocol for each protocol +# self.ably.options.use_binary_protocol = use_binary_protocol +# self.ably2.options.use_binary_protocol = use_binary_protocol +# self.use_binary_protocol = use_binary_protocol +# +# @dont_vary_protocol +# def test_cbc_channel_cipher(self): +# key = ( +# b'\x93\xe3\x5c\xc9\x77\x53\xfd\x1a' +# b'\x79\xb4\xd8\x84\xe7\xdc\xfd\xdf') +# +# iv = ( +# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' +# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0') +# +# log.debug("KEY_LEN: %d" % len(key)) +# log.debug("IV_LEN: %d" % len(iv)) +# cipher = get_cipher({'key': key, 'iv': iv}) +# +# plaintext = b"The quick brown fox" +# expected_ciphertext = ( +# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' +# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0' +# b'\x83\x5c\xcf\xce\x0c\xfd\xbe\x37' +# b'\xb7\x92\x12\x04\x1d\x45\x68\xa4' +# b'\xdf\x7f\x6e\x38\x17\x4a\xff\x50' +# b'\x73\x23\xbb\xca\x16\xb0\xe2\x84') +# +# actual_ciphertext = cipher.encrypt(plaintext) +# +# assert expected_ciphertext == actual_ciphertext +# +# async def test_crypto_publish(self): +# channel_name = self.get_channel_name('persisted:crypto_publish_text') +# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# history = await publish0.history() +# messages = history.items +# assert messages is not None, "Expected non-None messages" +# assert 4 == len(messages), "Expected 4 messages" +# +# message_contents = dict((m.name, m.data) for m in messages) +# log.debug("message_contents: %s" % str(message_contents)) +# +# assert "This is a string message payload" == message_contents["publish3"],\ +# "Expect publish3 to be expected String)" +# +# assert b"This is a byte[] message payload" == message_contents["publish4"],\ +# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) +# +# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ +# "Expect publish5 to be expected JSONObject" +# +# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ +# "Expect publish6 to be expected JSONObject" +# +# async def test_crypto_publish_256(self): +# rndfile = Random.new() +# key = rndfile.read(32) +# channel_name = 'persisted:crypto_publish_text_256' +# channel_name += '_bin' if self.use_binary_protocol else '_text' +# +# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# history = await publish0.history() +# messages = history.items +# assert messages is not None, "Expected non-None messages" +# assert 4 == len(messages), "Expected 4 messages" +# +# message_contents = dict((m.name, m.data) for m in messages) +# log.debug("message_contents: %s" % str(message_contents)) +# +# assert "This is a string message payload" == message_contents["publish3"],\ +# "Expect publish3 to be expected String)" +# +# assert b"This is a byte[] message payload" == message_contents["publish4"],\ +# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) +# +# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ +# "Expect publish5 to be expected JSONObject" +# +# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ +# "Expect publish6 to be expected JSONObject" +# +# async def test_crypto_publish_key_mismatch(self): +# channel_name = self.get_channel_name('persisted:crypto_publish_key_mismatch') +# +# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# with pytest.raises(AblyException) as excinfo: +# await rx_channel.history() +# +# message = excinfo.value.message +# assert 'invalid-padding' == message or "codec can't decode" in message +# +# async def test_crypto_send_unencrypted(self): +# channel_name = self.get_channel_name('persisted:crypto_send_unencrypted') +# publish0 = self.ably.channels[channel_name] +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# history = await rx_channel.history() +# messages = history.items +# assert messages is not None, "Expected non-None messages" +# assert 4 == len(messages), "Expected 4 messages" +# +# message_contents = dict((m.name, m.data) for m in messages) +# log.debug("message_contents: %s" % str(message_contents)) +# +# assert "This is a string message payload" == message_contents["publish3"],\ +# "Expect publish3 to be expected String" +# +# assert b"This is a byte[] message payload" == message_contents["publish4"],\ +# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) +# +# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ +# "Expect publish5 to be expected JSONObject" +# +# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ +# "Expect publish6 to be expected JSONObject" +# +# async def test_crypto_encrypted_unhandled(self): +# channel_name = self.get_channel_name('persisted:crypto_send_encrypted_unhandled') +# key = b'0123456789abcdef' +# data = 'foobar' +# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) +# +# await publish0.publish("publish0", data) +# +# rx_channel = self.ably2.channels[channel_name] +# history = await rx_channel.history() +# message = history.items[0] +# cipher = get_cipher(get_default_params({'key': key})) +# assert cipher.decrypt(message.data).decode() == data +# assert message.encoding == 'utf-8/cipher+aes-128-cbc' +# +# @dont_vary_protocol +# def test_cipher_params(self): +# params = CipherParams(secret_key='0123456789abcdef') +# assert params.algorithm == 'AES' +# assert params.mode == 'CBC' +# assert params.key_length == 128 +# +# params = CipherParams(secret_key='0123456789abcdef' * 2) +# assert params.algorithm == 'AES' +# assert params.mode == 'CBC' +# assert params.key_length == 256 +# +# +# class AbstractTestCryptoWithFixture: +# +# @classmethod +# def setUpClass(cls): +# resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file +# with open(resources_path, 'r') as f: +# cls.fixture = json.loads(f.read()) +# cls.params = { +# 'secret_key': base64.b64decode(cls.fixture['key'].encode('ascii')), +# 'mode': cls.fixture['mode'], +# 'algorithm': cls.fixture['algorithm'], +# 'iv': base64.b64decode(cls.fixture['iv'].encode('ascii')), +# } +# cls.cipher_params = CipherParams(**cls.params) +# cls.cipher = get_cipher(cls.cipher_params) +# cls.items = cls.fixture['items'] +# +# def get_encoded(self, encoded_item): +# if encoded_item.get('encoding') == 'base64': +# return base64.b64decode(encoded_item['data'].encode('ascii')) +# elif encoded_item.get('encoding') == 'json': +# return json.loads(encoded_item['data']) +# return encoded_item['data'] +# +# # TM3 +# def test_decode(self): +# for item in self.items: +# assert item['encoded']['name'] == item['encrypted']['name'] +# message = Message.from_encoded(item['encrypted'], self.cipher) +# assert message.encoding == '' +# expected_data = self.get_encoded(item['encoded']) +# assert expected_data == message.data +# +# # TM3 +# def test_decode_array(self): +# items_encrypted = [item['encrypted'] for item in self.items] +# messages = Message.from_encoded_array(items_encrypted, self.cipher) +# for i, message in enumerate(messages): +# assert message.encoding == '' +# expected_data = self.get_encoded(self.items[i]['encoded']) +# assert expected_data == message.data +# +# def test_encode(self): +# for item in self.items: +# # need to reset iv +# self.cipher_params = CipherParams(**self.params) +# self.cipher = get_cipher(self.cipher_params) +# data = self.get_encoded(item['encoded']) +# expected = item['encrypted'] +# message = Message(item['encoded']['name'], data) +# message.encrypt(self.cipher) +# as_dict = message.as_dict() +# assert as_dict['data'] == expected['data'] +# assert as_dict['encoding'] == expected['encoding'] +# +# +# class TestCryptoWithFixture128(AbstractTestCryptoWithFixture, BaseTestCase): +# fixture_file = 'crypto-data-128.json' +# +# +# class TestCryptoWithFixture256(AbstractTestCryptoWithFixture, BaseTestCase): +# fixture_file = 'crypto-data-256.json' diff --git a/test/ably/sync/rest/encoders_test.py b/test/ably/sync/rest/encoders_test.py new file mode 100644 index 00000000..83d2e852 --- /dev/null +++ b/test/ably/sync/rest/encoders_test.py @@ -0,0 +1,456 @@ +import base64 +import json +import logging +import sys + +import mock +import msgpack + +from ably.sync import CipherParams +from ably.sync.util.crypto import get_cipher +from ably.sync.types.message import Message + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import BaseAsyncTestCase + +if sys.version_info >= (3, 8): + from unittest.mock import Mock +else: + from mock import Mock + +log = logging.getLogger(__name__) + + +class TestTextEncodersNoEncryption(BaseAsyncTestCase): + def setUp(self): + self.ably = TestApp.get_ably_rest(use_binary_protocol=False) + + def tearDown(self): + self.ably.close() + + def test_text_utf8(self): + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', 'foó') + _, kwargs = post_mock.call_args + assert json.loads(kwargs['body'])['data'] == 'foó' + assert not json.loads(kwargs['body']).get('encoding', '') + + def test_str(self): + # This test only makes sense for py2 + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', 'foo') + _, kwargs = post_mock.call_args + assert json.loads(kwargs['body'])['data'] == 'foo' + assert not json.loads(kwargs['body']).get('encoding', '') + + def test_with_binary_type(self): + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', bytearray(b'foo')) + _, kwargs = post_mock.call_args + raw_data = json.loads(kwargs['body'])['data'] + assert base64.b64decode(raw_data.encode('ascii')) == bytearray(b'foo') + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'base64' + + def test_with_bytes_type(self): + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', b'foo') + _, kwargs = post_mock.call_args + raw_data = json.loads(kwargs['body'])['data'] + assert base64.b64decode(raw_data.encode('ascii')) == bytearray(b'foo') + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'base64' + + def test_with_json_dict_data(self): + channel = self.ably.channels["persisted:publish"] + data = {'foó': 'bár'} + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + raw_data = json.loads(json.loads(kwargs['body'])['data']) + assert raw_data == data + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json' + + def test_with_json_list_data(self): + channel = self.ably.channels["persisted:publish"] + data = ['foó', 'bár'] + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + raw_data = json.loads(json.loads(kwargs['body'])['data']) + assert raw_data == data + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json' + + def test_text_utf8_decode(self): + channel = self.ably.channels["persisted:stringdecode"] + + channel.publish('event', 'fóo') + history = channel.history() + message = history.items[0] + assert message.data == 'fóo' + assert isinstance(message.data, str) + assert not message.encoding + + def test_text_str_decode(self): + channel = self.ably.channels["persisted:stringnonutf8decode"] + + channel.publish('event', 'foo') + history = channel.history() + message = history.items[0] + assert message.data == 'foo' + assert isinstance(message.data, str) + assert not message.encoding + + def test_with_binary_type_decode(self): + channel = self.ably.channels["persisted:binarydecode"] + + channel.publish('event', bytearray(b'foob')) + history = channel.history() + message = history.items[0] + assert message.data == bytearray(b'foob') + assert isinstance(message.data, bytearray) + assert not message.encoding + + def test_with_json_dict_data_decode(self): + channel = self.ably.channels["persisted:jsondict"] + data = {'foó': 'bár'} + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + def test_with_json_list_data_decode(self): + channel = self.ably.channels["persisted:jsonarray"] + data = ['foó', 'bár'] + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + def test_decode_with_invalid_encoding(self): + data = 'foó' + encoded = base64.b64encode(data.encode('utf-8')) + decoded_data = Message.decode(encoded, 'foo/bar/utf-8/base64') + assert decoded_data['data'] == data + assert decoded_data['encoding'] == 'foo/bar' + + +class TestTextEncodersEncryption(BaseAsyncTestCase): + def setUp(self): + self.ably = TestApp.get_ably_rest(use_binary_protocol=False) + self.cipher_params = CipherParams(secret_key='keyfordecrypt_16', + algorithm='aes') + + def tearDown(self): + self.ably.close() + + def decrypt(self, payload, options=None): + if options is None: + options = {} + ciphertext = base64.b64decode(payload.encode('ascii')) + cipher = get_cipher({'key': b'keyfordecrypt_16'}) + return cipher.decrypt(ciphertext) + + def test_text_utf8(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', 'fóo') + _, kwargs = post_mock.call_args + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'utf-8/cipher+aes-128-cbc/base64' + data = self.decrypt(json.loads(kwargs['body'])['data']).decode('utf-8') + assert data == 'fóo' + + def test_str(self): + # This test only makes sense for py2 + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', 'foo') + _, kwargs = post_mock.call_args + assert json.loads(kwargs['body'])['data'] == 'foo' + assert not json.loads(kwargs['body']).get('encoding', '') + + def test_with_binary_type(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', bytearray(b'foo')) + _, kwargs = post_mock.call_args + + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'cipher+aes-128-cbc/base64' + data = self.decrypt(json.loads(kwargs['body'])['data']) + assert data == bytearray(b'foo') + assert isinstance(data, bytearray) + + def test_with_json_dict_data(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + data = {'foó': 'bár'} + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' + raw_data = self.decrypt(json.loads(kwargs['body'])['data']).decode('ascii') + assert json.loads(raw_data) == data + + def test_with_json_list_data(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + data = ['foó', 'bár'] + with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' + raw_data = self.decrypt(json.loads(kwargs['body'])['data']).decode('ascii') + assert json.loads(raw_data) == data + + def test_text_utf8_decode(self): + channel = self.ably.channels.get("persisted:enc_stringdecode", + cipher=self.cipher_params) + channel.publish('event', 'foó') + history = channel.history() + message = history.items[0] + assert message.data == 'foó' + assert isinstance(message.data, str) + assert not message.encoding + + def test_with_binary_type_decode(self): + channel = self.ably.channels.get("persisted:enc_binarydecode", + cipher=self.cipher_params) + + channel.publish('event', bytearray(b'foob')) + history = channel.history() + message = history.items[0] + assert message.data == bytearray(b'foob') + assert isinstance(message.data, bytearray) + assert not message.encoding + + def test_with_json_dict_data_decode(self): + channel = self.ably.channels.get("persisted:enc_jsondict", + cipher=self.cipher_params) + data = {'foó': 'bár'} + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + def test_with_json_list_data_decode(self): + channel = self.ably.channels.get("persisted:enc_list", + cipher=self.cipher_params) + data = ['foó', 'bár'] + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + +class TestBinaryEncodersNoEncryption(BaseAsyncTestCase): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + + def tearDown(self): + self.ably.close() + + def decode(self, data): + return msgpack.unpackb(data) + + def test_text_utf8(self): + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', 'foó') + _, kwargs = post_mock.call_args + assert self.decode(kwargs['body'])['data'] == 'foó' + assert self.decode(kwargs['body']).get('encoding', '').strip('/') == '' + + def test_with_binary_type(self): + channel = self.ably.channels["persisted:publish"] + + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', bytearray(b'foo')) + _, kwargs = post_mock.call_args + assert self.decode(kwargs['body'])['data'] == bytearray(b'foo') + assert self.decode(kwargs['body']).get('encoding', '').strip('/') == '' + + def test_with_json_dict_data(self): + channel = self.ably.channels["persisted:publish"] + data = {'foó': 'bár'} + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + raw_data = json.loads(self.decode(kwargs['body'])['data']) + assert raw_data == data + assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json' + + def test_with_json_list_data(self): + channel = self.ably.channels["persisted:publish"] + data = ['foó', 'bár'] + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + raw_data = json.loads(self.decode(kwargs['body'])['data']) + assert raw_data == data + assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json' + + def test_text_utf8_decode(self): + channel = self.ably.channels["persisted:stringdecode-bin"] + + channel.publish('event', 'fóo') + history = channel.history() + message = history.items[0] + assert message.data == 'fóo' + assert isinstance(message.data, str) + assert not message.encoding + + def test_with_binary_type_decode(self): + channel = self.ably.channels["persisted:binarydecode-bin"] + + channel.publish('event', bytearray(b'foob')) + history = channel.history() + message = history.items[0] + assert message.data == bytearray(b'foob') + assert not message.encoding + + def test_with_json_dict_data_decode(self): + channel = self.ably.channels["persisted:jsondict-bin"] + data = {'foó': 'bár'} + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + def test_with_json_list_data_decode(self): + channel = self.ably.channels["persisted:jsonarray-bin"] + data = ['foó', 'bár'] + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + +class TestBinaryEncodersEncryption(BaseAsyncTestCase): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + self.cipher_params = CipherParams(secret_key='keyfordecrypt_16', algorithm='aes') + + def tearDown(self): + self.ably.close() + + def decrypt(self, payload, options=None): + if options is None: + options = {} + cipher = get_cipher({'key': b'keyfordecrypt_16'}) + return cipher.decrypt(payload) + + def decode(self, data): + return msgpack.unpackb(data) + + def test_text_utf8(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', 'fóo') + _, kwargs = post_mock.call_args + assert self.decode(kwargs['body'])['encoding'].strip('/') == 'utf-8/cipher+aes-128-cbc' + data = self.decrypt(self.decode(kwargs['body'])['data']).decode('utf-8') + assert data == 'fóo' + + def test_with_binary_type(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', bytearray(b'foo')) + _, kwargs = post_mock.call_args + + assert self.decode(kwargs['body'])['encoding'].strip('/') == 'cipher+aes-128-cbc' + data = self.decrypt(self.decode(kwargs['body'])['data']) + assert data == bytearray(b'foo') + assert isinstance(data, bytearray) + + def test_with_json_dict_data(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + data = {'foó': 'bár'} + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc' + raw_data = self.decrypt(self.decode(kwargs['body'])['data']).decode('ascii') + assert json.loads(raw_data) == data + + def test_with_json_list_data(self): + channel = self.ably.channels.get("persisted:publish_enc", + cipher=self.cipher_params) + data = ['foó', 'bár'] + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish('event', data) + _, kwargs = post_mock.call_args + assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc' + raw_data = self.decrypt(self.decode(kwargs['body'])['data']).decode('ascii') + assert json.loads(raw_data) == data + + def test_text_utf8_decode(self): + channel = self.ably.channels.get("persisted:enc_stringdecode-bin", + cipher=self.cipher_params) + channel.publish('event', 'foó') + history = channel.history() + message = history.items[0] + assert message.data == 'foó' + assert isinstance(message.data, str) + assert not message.encoding + + def test_with_binary_type_decode(self): + channel = self.ably.channels.get("persisted:enc_binarydecode-bin", + cipher=self.cipher_params) + + channel.publish('event', bytearray(b'foob')) + history = channel.history() + message = history.items[0] + assert message.data == bytearray(b'foob') + assert isinstance(message.data, bytearray) + assert not message.encoding + + def test_with_json_dict_data_decode(self): + channel = self.ably.channels.get("persisted:enc_jsondict-bin", + cipher=self.cipher_params) + data = {'foó': 'bár'} + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding + + def test_with_json_list_data_decode(self): + channel = self.ably.channels.get("persisted:enc_list-bin", + cipher=self.cipher_params) + data = ['foó', 'bár'] + channel.publish('event', data) + history = channel.history() + message = history.items[0] + assert message.data == data + assert not message.encoding diff --git a/test/ably/sync/rest/restauth_test.py b/test/ably/sync/rest/restauth_test.py new file mode 100644 index 00000000..4ca85f45 --- /dev/null +++ b/test/ably/sync/rest/restauth_test.py @@ -0,0 +1,652 @@ +import logging +import sys +import time +import uuid +import base64 + +from urllib.parse import parse_qs +import mock +import pytest +import respx +from httpx import Response, Client + +import ably +from ably.sync import AblyRest +from ably.sync import Auth +from ably.sync import AblyAuthException +from ably.sync.types.tokendetails import TokenDetails + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + +if sys.version_info >= (3, 8): + from unittest.mock import Mock +else: + from mock import Mock + +log = logging.getLogger(__name__) + + +# does not make any request, no need to vary by protocol +class TestAuth(BaseAsyncTestCase): + def setUp(self): + self.test_vars = TestApp.get_test_vars() + + def test_auth_init_key_only(self): + ably = AblyRest(key=self.test_vars["keys"][0]["key_str"]) + assert Auth.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert ably.auth.auth_options.key_name == self.test_vars["keys"][0]['key_name'] + assert ably.auth.auth_options.key_secret == self.test_vars["keys"][0]['key_secret'] + + def test_auth_init_token_only(self): + ably = AblyRest(token="this_is_not_really_a_token") + + assert Auth.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + + def test_auth_token_details(self): + td = TokenDetails() + ably = AblyRest(token_details=td) + + assert Auth.Method.TOKEN == ably.auth.auth_mechanism + assert ably.auth.token_details is td + + def test_auth_init_with_token_callback(self): + callback_called = [] + + def token_callback(token_params): + callback_called.append(True) + return "this_is_not_really_a_token_request" + + ably = TestApp.get_ably_rest( + key=None, + key_name=self.test_vars["keys"][0]["key_name"], + auth_callback=token_callback) + + try: + ably.stats(None) + except Exception: + pass + + assert callback_called, "Token callback not called" + assert Auth.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + + def test_auth_init_with_key_and_client_id(self): + ably = AblyRest(key=self.test_vars["keys"][0]["key_str"], client_id='testClientId') + + assert Auth.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert ably.auth.client_id == 'testClientId' + + def test_auth_init_with_token(self): + ably = TestApp.get_ably_rest(key=None, token="this_is_not_really_a_token") + assert Auth.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + + # RSA11 + def test_request_basic_auth_header(self): + ably = AblyRest(key_secret='foo', key_name='bar') + + with mock.patch.object(Client, 'send') as get_mock: + try: + ably.http.get('/time', skip_auth=False) + except Exception: + pass + request = get_mock.call_args_list[0][0][0] + authorization = request.headers['Authorization'] + assert authorization == 'Basic %s' % base64.b64encode('bar:foo'.encode('ascii')).decode('utf-8') + + # RSA7e2 + def test_request_basic_auth_header_with_client_id(self): + ably = AblyRest(key_secret='foo', key_name='bar', client_id='client_id') + + with mock.patch.object(Client, 'send') as get_mock: + try: + ably.http.get('/time', skip_auth=False) + except Exception: + pass + request = get_mock.call_args_list[0][0][0] + client_id = request.headers['x-ably-clientid'] + assert client_id == base64.b64encode('client_id'.encode('ascii')).decode('utf-8') + + def test_request_token_auth_header(self): + ably = AblyRest(token='not_a_real_token') + + with mock.patch.object(Client, 'send') as get_mock: + try: + ably.http.get('/time', skip_auth=False) + except Exception: + pass + request = get_mock.call_args_list[0][0][0] + authorization = request.headers['Authorization'] + assert authorization == 'Bearer %s' % base64.b64encode('not_a_real_token'.encode('ascii')).decode('utf-8') + + def test_if_cant_authenticate_via_token(self): + with pytest.raises(ValueError): + AblyRest(use_token_auth=True) + + def test_use_auth_token(self): + ably = AblyRest(use_token_auth=True, key=self.test_vars["keys"][0]["key_str"]) + assert ably.auth.auth_mechanism == Auth.Method.TOKEN + + def test_with_client_id(self): + ably = AblyRest(use_token_auth=True, client_id='client_id', key=self.test_vars["keys"][0]["key_str"]) + assert ably.auth.auth_mechanism == Auth.Method.TOKEN + + def test_with_auth_url(self): + ably = AblyRest(auth_url='auth_url') + assert ably.auth.auth_mechanism == Auth.Method.TOKEN + + def test_with_auth_callback(self): + ably = AblyRest(auth_callback=lambda x: x) + assert ably.auth.auth_mechanism == Auth.Method.TOKEN + + def test_with_token(self): + ably = AblyRest(token='a token') + assert ably.auth.auth_mechanism == Auth.Method.TOKEN + + def test_default_ttl_is_1hour(self): + one_hour_in_ms = 60 * 60 * 1000 + assert TokenDetails.DEFAULTS['ttl'] == one_hour_in_ms + + def test_with_auth_method(self): + ably = AblyRest(token='a token', auth_method='POST') + assert ably.auth.auth_options.auth_method == 'POST' + + def test_with_auth_headers(self): + ably = AblyRest(token='a token', auth_headers={'h1': 'v1'}) + assert ably.auth.auth_options.auth_headers == {'h1': 'v1'} + + def test_with_auth_params(self): + ably = AblyRest(token='a token', auth_params={'p': 'v'}) + assert ably.auth.auth_options.auth_params == {'p': 'v'} + + def test_with_default_token_params(self): + ably = AblyRest(key=self.test_vars["keys"][0]["key_str"], + default_token_params={'ttl': 12345}) + assert ably.auth.auth_options.default_token_params == {'ttl': 12345} + + +class TestAuthAuthorize(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + self.test_vars = TestApp.get_test_vars() + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + def test_if_authorize_changes_auth_mechanism_to_token(self): + assert Auth.Method.BASIC == self.ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + + self.ably.auth.authorize() + + assert Auth.Method.TOKEN == self.ably.auth.auth_mechanism, "Authorize should change the Auth method" + + # RSA10a + @dont_vary_protocol + def test_authorize_always_creates_new_token(self): + self.ably.auth.authorize({'capability': {'test': ['publish']}}) + self.ably.channels.test.publish('event', 'data') + + self.ably.auth.authorize({'capability': {'test': ['subscribe']}}) + with pytest.raises(AblyAuthException): + self.ably.channels.test.publish('event', 'data') + + def test_authorize_create_new_token_if_expired(self): + token = self.ably.auth.authorize() + with mock.patch('ably.rest.auth.Auth.token_details_has_expired', + return_value=True): + new_token = self.ably.auth.authorize() + + assert token is not new_token + + def test_authorize_returns_a_token_details(self): + token = self.ably.auth.authorize() + assert isinstance(token, TokenDetails) + + @dont_vary_protocol + def test_authorize_adheres_to_request_token(self): + token_params = {'ttl': 10, 'client_id': 'client_id'} + auth_params = {'auth_url': 'somewhere.com', 'query_time': True} + with mock.patch('ably.sync.rest.auth.Auth.request_token', new_callable=Mock) as request_mock: + self.ably.auth.authorize(token_params, auth_params) + + token_called, auth_called = request_mock.call_args + assert token_called[0] == token_params + + # Authorize may call request_token with some default auth_options. + for arg, value in auth_params.items(): + assert auth_called[arg] == value, "%s called with wrong value: %s" % (arg, value) + + def test_with_token_str_https(self): + token = self.ably.auth.authorize() + token = token.token + ably = TestApp.get_ably_rest(key=None, token=token, tls=True, + use_binary_protocol=self.use_binary_protocol) + ably.channels.test_auth_with_token_str.publish('event', 'foo_bar') + ably.close() + + def test_with_token_str_http(self): + token = self.ably.auth.authorize() + token = token.token + ably = TestApp.get_ably_rest(key=None, token=token, tls=False, + use_binary_protocol=self.use_binary_protocol) + ably.channels.test_auth_with_token_str.publish('event', 'foo_bar') + ably.close() + + def test_if_default_client_id_is_used(self): + ably = TestApp.get_ably_rest(client_id='my_client_id', + use_binary_protocol=self.use_binary_protocol) + token = ably.auth.authorize() + assert token.client_id == 'my_client_id' + ably.close() + + # RSA10j + def test_if_parameters_are_stored_and_used_as_defaults(self): + # Define some parameters + auth_options = dict(self.ably.auth.auth_options.auth_options) + auth_options['auth_headers'] = {'a_headers': 'a_value'} + self.ably.auth.authorize({'ttl': 555}, auth_options) + with mock.patch('ably.sync.rest.auth.Auth.request_token', + wraps=self.ably.auth.request_token) as request_mock: + self.ably.auth.authorize() + + token_called, auth_called = request_mock.call_args + assert token_called[0] == {'ttl': 555} + assert auth_called['auth_headers'] == {'a_headers': 'a_value'} + + # Different parameters, should completely replace the first ones, not merge + auth_options = dict(self.ably.auth.auth_options.auth_options) + auth_options['auth_headers'] = None + self.ably.auth.authorize({}, auth_options) + with mock.patch('ably.sync.rest.auth.Auth.request_token', + wraps=self.ably.auth.request_token) as request_mock: + self.ably.auth.authorize() + + token_called, auth_called = request_mock.call_args + assert token_called[0] == {} + assert auth_called['auth_headers'] is None + + # RSA10g + def test_timestamp_is_not_stored(self): + # authorize once with arbitrary defaults + auth_options = dict(self.ably.auth.auth_options.auth_options) + auth_options['auth_headers'] = {'a_headers': 'a_value'} + token_1 = self.ably.auth.authorize( + {'ttl': 60 * 1000, 'client_id': 'new_id'}, + auth_options) + assert isinstance(token_1, TokenDetails) + + # call authorize again with timestamp set + timestamp = self.ably.time() + with mock.patch('ably.sync.rest.auth.TokenRequest', + wraps=ably.types.tokenrequest.TokenRequest) as tr_mock: + auth_options = dict(self.ably.auth.auth_options.auth_options) + auth_options['auth_headers'] = {'a_headers': 'a_value'} + token_2 = self.ably.auth.authorize( + {'ttl': 60 * 1000, 'client_id': 'new_id', 'timestamp': timestamp}, + auth_options) + assert isinstance(token_2, TokenDetails) + assert token_1 != token_2 + assert tr_mock.call_args[1]['timestamp'] == timestamp + + # call authorize again with no params + with mock.patch('ably.sync.rest.auth.TokenRequest', + wraps=ably.types.tokenrequest.TokenRequest) as tr_mock: + token_4 = self.ably.auth.authorize() + assert isinstance(token_4, TokenDetails) + assert token_2 != token_4 + assert tr_mock.call_args[1]['timestamp'] != timestamp + + def test_client_id_precedence(self): + client_id = uuid.uuid4().hex + overridden_client_id = uuid.uuid4().hex + ably = TestApp.get_ably_rest( + use_binary_protocol=self.use_binary_protocol, + client_id=client_id, + default_token_params={'client_id': overridden_client_id}) + token = ably.auth.authorize() + assert token.client_id == client_id + assert ably.auth.client_id == client_id + + channel = ably.channels[ + self.get_channel_name('test_client_id_precedence')] + channel.publish('test', 'data') + history = channel.history() + assert history.items[0].client_id == client_id + ably.close() + + +class TestRequestToken(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + + def per_protocol_setup(self, use_binary_protocol): + self.use_binary_protocol = use_binary_protocol + + def test_with_key(self): + ably = TestApp.get_ably_rest(use_binary_protocol=self.use_binary_protocol) + + token_details = ably.auth.request_token() + assert isinstance(token_details, TokenDetails) + ably.close() + + ably = TestApp.get_ably_rest(key=None, token_details=token_details, + use_binary_protocol=self.use_binary_protocol) + channel = self.get_channel_name('test_request_token_with_key') + + ably.channels[channel].publish('event', 'foo') + + history = ably.channels[channel].history() + assert history.items[0].data == 'foo' + ably.close() + + @dont_vary_protocol + @respx.mock + def test_with_auth_url_headers_and_params_POST(self): # noqa: N802 + url = 'http://www.example.com' + headers = {'foo': 'bar'} + ably = TestApp.get_ably_rest(key=None, auth_url=url) + + auth_params = {'foo': 'auth', 'spam': 'eggs'} + token_params = {'foo': 'token'} + auth_route = respx.post(url) + + def call_back(request): + assert request.headers['content-type'] == 'application/x-www-form-urlencoded' + assert headers['foo'] == request.headers['foo'] + + # TokenParams has precedence + assert parse_qs(request.content.decode('utf-8')) == {'foo': ['token'], 'spam': ['eggs']} + return Response( + status_code=200, + content="token_string", + headers={ + "Content-Type": "text/plain", + } + ) + + auth_route.side_effect = call_back + token_details = ably.auth.request_token( + token_params=token_params, auth_url=url, auth_headers=headers, + auth_method='POST', auth_params=auth_params) + + assert 1 == auth_route.called + assert isinstance(token_details, TokenDetails) + assert 'token_string' == token_details.token + ably.close() + + @dont_vary_protocol + @respx.mock + def test_with_auth_url_headers_and_params_GET(self): # noqa: N802 + url = 'http://www.example.com' + headers = {'foo': 'bar'} + ably = TestApp.get_ably_rest( + key=None, auth_url=url, + auth_headers={'this': 'will_not_be_used'}, + auth_params={'this': 'will_not_be_used'}) + + auth_params = {'foo': 'auth', 'spam': 'eggs'} + token_params = {'foo': 'token'} + auth_route = respx.get(url, params={'foo': ['token'], 'spam': ['eggs']}) + + def call_back(request): + assert request.headers['foo'] == 'bar' + assert 'this' not in request.headers + assert not request.content + + return Response( + status_code=200, + json={'issued': 1, 'token': 'another_token_string'} + ) + auth_route.side_effect = call_back + token_details = ably.auth.request_token( + token_params=token_params, auth_url=url, auth_headers=headers, + auth_params=auth_params) + assert 'another_token_string' == token_details.token + ably.close() + + @dont_vary_protocol + def test_with_callback(self): + called_token_params = {'ttl': '3600000'} + + def callback(token_params): + assert token_params == called_token_params + return 'token_string' + + ably = TestApp.get_ably_rest(key=None, auth_callback=callback) + + token_details = ably.auth.request_token( + token_params=called_token_params, auth_callback=callback) + assert isinstance(token_details, TokenDetails) + assert 'token_string' == token_details.token + + def callback(token_params): + assert token_params == called_token_params + return TokenDetails(token='another_token_string') + + token_details = ably.auth.request_token( + token_params=called_token_params, auth_callback=callback) + assert 'another_token_string' == token_details.token + ably.close() + + @dont_vary_protocol + @respx.mock + def test_when_auth_url_has_query_string(self): + url = 'http://www.example.com?with=query' + headers = {'foo': 'bar'} + ably = TestApp.get_ably_rest(key=None, auth_url=url) + auth_route = respx.get('http://www.example.com', params={'with': 'query', 'spam': 'eggs'}).mock( + return_value=Response(status_code=200, content='token_string', headers={"Content-Type": "text/plain"})) + ably.auth.request_token(auth_url=url, + auth_headers=headers, + auth_params={'spam': 'eggs'}) + assert auth_route.called + ably.close() + + @dont_vary_protocol + def test_client_id_null_for_anonymous_auth(self): + ably = TestApp.get_ably_rest( + key=None, + key_name=self.test_vars["keys"][0]["key_name"], + key_secret=self.test_vars["keys"][0]["key_secret"]) + token = ably.auth.authorize() + + assert isinstance(token, TokenDetails) + assert token.client_id is None + assert ably.auth.client_id is None + ably.close() + + @dont_vary_protocol + def test_client_id_null_until_auth(self): + client_id = uuid.uuid4().hex + token_ably = TestApp.get_ably_rest( + default_token_params={'client_id': client_id}) + # before auth, client_id is None + assert token_ably.auth.client_id is None + + token = token_ably.auth.authorize() + assert isinstance(token, TokenDetails) + + # after auth, client_id is defined + assert token.client_id == client_id + assert token_ably.auth.client_id == client_id + token_ably.close() + + +class TestRenewToken(BaseAsyncTestCase): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + self.host = 'fake-host.ably.io' + self.ably = TestApp.get_ably_rest(use_binary_protocol=False, rest_host=self.host) + # with headers + self.publish_attempts = 0 + self.channel = uuid.uuid4().hex + tokens = ['a_token', 'another_token'] + headers = {'Content-Type': 'application/json'} + self.mocked_api = respx.mock(base_url='https://{}'.format(self.host)) + self.request_token_route = self.mocked_api.post( + "/keys/{}/requestToken".format(self.test_vars["keys"][0]['key_name']), + name="request_token_route") + self.request_token_route.return_value = Response( + status_code=200, + headers=headers, + json={ + 'token': tokens[self.request_token_route.call_count - 1], + 'expires': (time.time() + 60) * 1000 + }, + ) + + def call_back(request): + self.publish_attempts += 1 + if self.publish_attempts in [1, 3]: + return Response( + status_code=201, + headers=headers, + json=[], + ) + return Response( + status_code=401, + headers=headers, + json={ + 'error': {'message': 'Authentication failure', 'statusCode': 401, 'code': 40140} + }, + ) + + self.publish_attempt_route = self.mocked_api.post("/channels/{}/messages".format(self.channel), + name="publish_attempt_route") + self.publish_attempt_route.side_effect = call_back + self.mocked_api.start() + + def tearDown(self): + # We need to have quiet here in order to do not have check if all endpoints were called + self.mocked_api.stop(quiet=True) + self.mocked_api.reset() + self.ably.close() + + # RSA4b + def test_when_renewable(self): + self.ably.auth.authorize() + self.ably.channels[self.channel].publish('evt', 'msg') + assert self.mocked_api["request_token_route"].call_count == 1 + assert self.publish_attempts == 1 + + # Triggers an authentication 401 failure which should automatically request a new token + self.ably.channels[self.channel].publish('evt', 'msg') + assert self.mocked_api["request_token_route"].call_count == 2 + assert self.publish_attempts == 3 + + # RSA4a + def test_when_not_renewable(self): + self.ably.close() + + self.ably = TestApp.get_ably_rest( + key=None, + rest_host=self.host, + token='token ID cannot be used to create a new token', + use_binary_protocol=False) + self.ably.channels[self.channel].publish('evt', 'msg') + assert self.publish_attempts == 1 + + publish = self.ably.channels[self.channel].publish + + match = "Need a new token but auth_options does not include a way to request one" + with pytest.raises(AblyAuthException, match=match): + publish('evt', 'msg') + + assert not self.mocked_api["request_token_route"].called + + # RSA4a + def test_when_not_renewable_with_token_details(self): + token_details = TokenDetails(token='a_dummy_token') + self.ably = TestApp.get_ably_rest( + key=None, + rest_host=self.host, + token_details=token_details, + use_binary_protocol=False) + self.ably.channels[self.channel].publish('evt', 'msg') + assert self.mocked_api["publish_attempt_route"].call_count == 1 + + publish = self.ably.channels[self.channel].publish + + match = "Need a new token but auth_options does not include a way to request one" + with pytest.raises(AblyAuthException, match=match): + publish('evt', 'msg') + + assert not self.mocked_api["request_token_route"].called + + +class TestRenewExpiredToken(BaseAsyncTestCase): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + self.publish_attempts = 0 + self.channel = uuid.uuid4().hex + + self.host = 'fake-host.ably.io' + key = self.test_vars["keys"][0]['key_name'] + headers = {'Content-Type': 'application/json'} + + self.mocked_api = respx.mock(base_url='https://{}'.format(self.host)) + self.request_token_route = self.mocked_api.post("/keys/{}/requestToken".format(key), + name="request_token_route") + self.request_token_route.return_value = Response( + status_code=200, + headers=headers, + json={ + 'token': 'a_token', + 'expires': int(time.time() * 1000), # Always expires + } + ) + self.publish_message_route = self.mocked_api.post("/channels/{}/messages".format(self.channel), + name="publish_message_route") + self.time_route = self.mocked_api.get("/time", name="time_route") + self.time_route.return_value = Response( + status_code=200, + headers=headers, + json=[int(time.time() * 1000)] + ) + + def cb_publish(request): + self.publish_attempts += 1 + if self.publish_fail: + self.publish_fail = False + return Response( + status_code=401, + json={ + 'error': {'message': 'Authentication failure', 'statusCode': 401, 'code': 40140} + } + ) + return Response( + status_code=201, + json='[]' + ) + + self.publish_message_route.side_effect = cb_publish + self.mocked_api.start() + + def tearDown(self): + self.mocked_api.stop(quiet=True) + self.mocked_api.reset() + + # RSA4b1 + def test_query_time_false(self): + ably = TestApp.get_ably_rest(rest_host=self.host) + ably.auth.authorize() + self.publish_fail = True + ably.channels[self.channel].publish('evt', 'msg') + assert self.publish_attempts == 2 + ably.close() + + # RSA4b1 + def test_query_time_true(self): + ably = TestApp.get_ably_rest(query_time=True, rest_host=self.host) + ably.auth.authorize() + self.publish_fail = False + ably.channels[self.channel].publish('evt', 'msg') + assert self.publish_attempts == 1 + ably.close() diff --git a/test/ably/sync/rest/restcapability_test.py b/test/ably/sync/rest/restcapability_test.py new file mode 100644 index 00000000..486f148c --- /dev/null +++ b/test/ably/sync/rest/restcapability_test.py @@ -0,0 +1,243 @@ +import pytest + +from ably.sync.types.capability import Capability +from ably.sync.util.exceptions import AblyException + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + + +class TestRestCapability(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + self.ably = TestApp.get_ably_rest() + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + + def test_blanket_intersection_with_key(self): + key = self.test_vars['keys'][1] + token_details = self.ably.auth.request_token(key_name=key['key_name'], + key_secret=key['key_secret']) + expected_capability = Capability(key["capability"]) + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability." + + def test_equal_intersection_with_key(self): + key = self.test_vars['keys'][1] + + token_details = self.ably.auth.request_token( + key_name=key['key_name'], + key_secret=key['key_secret'], + token_params={'capability': key['capability']}) + + expected_capability = Capability(key["capability"]) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + @dont_vary_protocol + def test_empty_ops_intersection(self): + key = self.test_vars['keys'][1] + with pytest.raises(AblyException): + self.ably.auth.request_token( + key_name=key['key_name'], + key_secret=key['key_secret'], + token_params={'capability': {'testchannel': ['subscribe']}}) + + @dont_vary_protocol + def test_empty_paths_intersection(self): + key = self.test_vars['keys'][1] + with pytest.raises(AblyException): + self.ably.auth.request_token( + key_name=key['key_name'], + key_secret=key['key_secret'], + token_params={'capability': {"testchannelx": ["publish"]}}) + + def test_non_empty_ops_intersection(self): + key = self.test_vars['keys'][4] + + token_params = {"capability": { + "channel2": ["presence", "subscribe"] + }} + kwargs = { + "key_name": key["key_name"], + "key_secret": key["key_secret"], + } + + expected_capability = Capability({ + "channel2": ["subscribe"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + def test_non_empty_paths_intersection(self): + key = self.test_vars['keys'][4] + token_params = { + "capability": { + "channel2": ["presence", "subscribe"], + "channelx": ["presence", "subscribe"], + } + } + kwargs = { + "key_name": key["key_name"], + + "key_secret": key["key_secret"] + } + + expected_capability = Capability({ + "channel2": ["subscribe"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + def test_wildcard_ops_intersection(self): + key = self.test_vars['keys'][4] + + token_params = { + "capability": { + "channel2": ["*"], + }, + } + kwargs = { + "key_name": key["key_name"], + "key_secret": key["key_secret"], + } + + expected_capability = Capability({ + "channel2": ["subscribe", "publish"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + def test_wildcard_ops_intersection_2(self): + key = self.test_vars['keys'][4] + + token_params = { + "capability": { + "channel6": ["publish", "subscribe"], + }, + } + kwargs = { + "key_name": key["key_name"], + "key_secret": key["key_secret"], + } + + expected_capability = Capability({ + "channel6": ["subscribe", "publish"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + def test_wildcard_resources_intersection(self): + key = self.test_vars['keys'][2] + + token_params = { + "capability": { + "cansubscribe": ["subscribe"], + }, + } + kwargs = { + "key_name": key["key_name"], + "key_secret": key["key_secret"], + } + + expected_capability = Capability({ + "cansubscribe": ["subscribe"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + def test_wildcard_resources_intersection_2(self): + key = self.test_vars['keys'][2] + + token_params = { + "capability": { + "cansubscribe:check": ["subscribe"], + }, + } + kwargs = { + "key_name": key["key_name"], + "key_secret": key["key_secret"], + } + + expected_capability = Capability({ + "cansubscribe:check": ["subscribe"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + def test_wildcard_resources_intersection_3(self): + key = self.test_vars['keys'][2] + + token_params = { + "capability": { + "cansubscribe:*": ["subscribe"], + }, + } + kwargs = { + "key_name": key["key_name"], + "key_secret": key["key_secret"], + + } + + expected_capability = Capability({ + "cansubscribe:*": ["subscribe"] + }) + + token_details = self.ably.auth.request_token(token_params, **kwargs) + + assert token_details.token is not None, "Expected token" + assert expected_capability == token_details.capability, "Unexpected capability" + + @dont_vary_protocol + def test_invalid_capabilities(self): + with pytest.raises(AblyException) as excinfo: + self.ably.auth.request_token( + token_params={'capability': {"channel0": ["publish_"]}}) + + the_exception = excinfo.value + assert 400 == the_exception.status_code + assert 40000 == the_exception.code + + @dont_vary_protocol + def test_invalid_capabilities_2(self): + with pytest.raises(AblyException) as excinfo: + self.ably.auth.request_token( + token_params={'capability': {"channel0": ["*", "publish"]}}) + + the_exception = excinfo.value + assert 400 == the_exception.status_code + assert 40000 == the_exception.code + + @dont_vary_protocol + def test_invalid_capabilities_3(self): + with pytest.raises(AblyException) as excinfo: + self.ably.auth.request_token( + token_params={'capability': {"channel0": []}}) + + the_exception = excinfo.value + assert 400 == the_exception.status_code + assert 40000 == the_exception.code diff --git a/test/ably/sync/rest/restchannelhistory_test.py b/test/ably/sync/rest/restchannelhistory_test.py new file mode 100644 index 00000000..3c82fcc8 --- /dev/null +++ b/test/ably/sync/rest/restchannelhistory_test.py @@ -0,0 +1,332 @@ +import logging +import pytest +import respx + +from ably.sync import AblyException +from ably.sync.http.paginatedresult import PaginatedResult + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + +log = logging.getLogger(__name__) + + +class TestRestChannelHistory(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest(fallback_hosts=[]) + self.test_vars = TestApp.get_test_vars() + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + + def test_channel_history_types(self): + history0 = self.get_channel('persisted:channelhistory_types') + + history0.publish('history0', 'This is a string message payload') + history0.publish('history1', b'This is a byte[] message payload') + history0.publish('history2', {'test': 'This is a JSONObject message payload'}) + history0.publish('history3', ['This is a JSONArray message payload']) + + history = history0.history() + assert isinstance(history, PaginatedResult) + messages = history.items + assert messages is not None, "Expected non-None messages" + assert 4 == len(messages), "Expected 4 messages" + + message_contents = {m.name: m for m in messages} + assert "This is a string message payload" == message_contents["history0"].data, \ + "Expect history0 to be expected String)" + assert b"This is a byte[] message payload" == message_contents["history1"].data, \ + "Expect history1 to be expected byte[]" + assert {"test": "This is a JSONObject message payload"} == message_contents["history2"].data, \ + "Expect history2 to be expected JSONObject" + assert ["This is a JSONArray message payload"] == message_contents["history3"].data, \ + "Expect history3 to be expected JSONObject" + + expected_message_history = [ + message_contents['history3'], + message_contents['history2'], + message_contents['history1'], + message_contents['history0'], + ] + assert expected_message_history == messages, "Expect messages in reverse order" + + def test_channel_history_multi_50_forwards(self): + history0 = self.get_channel('persisted:channelhistory_multi_50_f') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='forwards') + assert history is not None + messages = history.items + assert len(messages) == 50, "Expected 50 messages" + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(50)] + assert messages == expected_messages, 'Expect messages in forward order' + + def test_channel_history_multi_50_backwards(self): + history0 = self.get_channel('persisted:channelhistory_multi_50_b') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='backwards') + assert history is not None + messages = history.items + assert 50 == len(messages), "Expected 50 messages" + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(49, -1, -1)] + assert expected_messages == messages, 'Expect messages in reverse order' + + def history_mock_url(self, channel_name): + kwargs = { + 'scheme': 'https' if self.test_vars['tls'] else 'http', + 'host': self.test_vars['host'], + 'channel_name': channel_name + } + port = self.test_vars['tls_port'] if self.test_vars.get('tls') else kwargs['port'] + if port == 80: + kwargs['port_sufix'] = '' + else: + kwargs['port_sufix'] = ':' + str(port) + url = '{scheme}://{host}{port_sufix}/channels/{channel_name}/messages' + return url.format(**kwargs) + + @respx.mock + @dont_vary_protocol + def test_channel_history_default_limit(self): + self.per_protocol_setup(True) + channel = self.ably.channels['persisted:channelhistory_limit'] + url = self.history_mock_url('persisted:channelhistory_limit') + self.respx_add_empty_msg_pack(url) + channel.history() + assert 'limit' not in respx.calls[0].request.url.params.keys() + + @respx.mock + @dont_vary_protocol + def test_channel_history_with_limits(self): + self.per_protocol_setup(True) + channel = self.ably.channels['persisted:channelhistory_limit'] + url = self.history_mock_url('persisted:channelhistory_limit') + self.respx_add_empty_msg_pack(url) + + channel.history(limit=500) + assert '500' in respx.calls[0].request.url.params.get('limit') + + channel.history(limit=1000) + assert '1000' in respx.calls[1].request.url.params.get('limit') + + @dont_vary_protocol + def test_channel_history_max_limit_is_1000(self): + channel = self.ably.channels['persisted:channelhistory_limit'] + with pytest.raises(AblyException): + channel.history(limit=1001) + + def test_channel_history_limit_forwards(self): + history0 = self.get_channel('persisted:channelhistory_limit_f') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='forwards', limit=25) + assert history is not None + messages = history.items + assert len(messages) == 25, "Expected 25 messages" + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(25)] + assert messages == expected_messages, 'Expect messages in forward order' + + def test_channel_history_limit_backwards(self): + history0 = self.get_channel('persisted:channelhistory_limit_b') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='backwards', limit=25) + assert history is not None + messages = history.items + assert len(messages) == 25, "Expected 25 messages" + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(49, 24, -1)] + assert messages == expected_messages, 'Expect messages in forward order' + + def test_channel_history_time_forwards(self): + history0 = self.get_channel('persisted:channelhistory_time_f') + + for i in range(20): + history0.publish('history%d' % i, str(i)) + + interval_start = self.ably.time() + + for i in range(20, 40): + history0.publish('history%d' % i, str(i)) + + interval_end = self.ably.time() + + for i in range(40, 60): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='forwards', start=interval_start, + end=interval_end) + + messages = history.items + assert 20 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(20, 40)] + assert expected_messages == messages, 'Expect messages in forward order' + + def test_channel_history_time_backwards(self): + history0 = self.get_channel('persisted:channelhistory_time_b') + + for i in range(20): + history0.publish('history%d' % i, str(i)) + + interval_start = self.ably.time() + + for i in range(20, 40): + history0.publish('history%d' % i, str(i)) + + interval_end = self.ably.time() + + for i in range(40, 60): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='backwards', start=interval_start, + end=interval_end) + + messages = history.items + assert 20 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(39, 19, -1)] + assert expected_messages, messages == 'Expect messages in reverse order' + + def test_channel_history_paginate_forwards(self): + history0 = self.get_channel('persisted:channelhistory_paginate_f') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='forwards', limit=10) + messages = history.items + + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(0, 10)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.next() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(10, 20)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.next() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(20, 30)] + assert expected_messages == messages, 'Expected 10 messages' + + def test_channel_history_paginate_backwards(self): + history0 = self.get_channel('persisted:channelhistory_paginate_b') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='backwards', limit=10) + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(49, 39, -1)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.next() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(39, 29, -1)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.next() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(29, 19, -1)] + assert expected_messages == messages, 'Expected 10 messages' + + def test_channel_history_paginate_forwards_first(self): + history0 = self.get_channel('persisted:channelhistory_paginate_first_f') + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='forwards', limit=10) + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(0, 10)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.next() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(10, 20)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.first() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(0, 10)] + assert expected_messages == messages, 'Expected 10 messages' + + def test_channel_history_paginate_backwards_rel_first(self): + history0 = self.get_channel('persisted:channelhistory_paginate_first_b') + + for i in range(50): + history0.publish('history%d' % i, str(i)) + + history = history0.history(direction='backwards', limit=10) + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(49, 39, -1)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.next() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(39, 29, -1)] + assert expected_messages == messages, 'Expected 10 messages' + + history = history.first() + messages = history.items + assert 10 == len(messages) + + message_contents = {m.name: m for m in messages} + expected_messages = [message_contents['history%d' % i] for i in range(49, 39, -1)] + assert expected_messages == messages, 'Expected 10 messages' diff --git a/test/ably/sync/rest/restchannelpublish_test.py b/test/ably/sync/rest/restchannelpublish_test.py new file mode 100644 index 00000000..a3c1ebcb --- /dev/null +++ b/test/ably/sync/rest/restchannelpublish_test.py @@ -0,0 +1,568 @@ +import base64 +import binascii +import json +import logging +import os +import uuid + +import httpx +import mock +import msgpack +import pytest + +from ably.sync import api_version +from ably.sync import AblyException, IncompatibleClientIdException +from ably.sync.rest.auth import Auth +from ably.sync.types.message import Message +from ably.sync.types.tokendetails import TokenDetails +from ably.sync.util import case + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + +log = logging.getLogger(__name__) + + +# Ignore library warning regarding client_id +@pytest.mark.filterwarnings('ignore::DeprecationWarning') +class TestRestChannelPublish(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + self.ably = TestApp.get_ably_rest() + self.client_id = uuid.uuid4().hex + self.ably_with_client_id = TestApp.get_ably_rest(client_id=self.client_id, use_token_auth=True) + + def tearDown(self): + self.ably.close() + self.ably_with_client_id.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.ably_with_client_id.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + def test_publish_various_datatypes_text(self): + publish0 = self.ably.channels[ + self.get_channel_name('persisted:publish0')] + + publish0.publish("publish0", "This is a string message payload") + publish0.publish("publish1", b"This is a byte[] message payload") + publish0.publish("publish2", {"test": "This is a JSONObject message payload"}) + publish0.publish("publish3", ["This is a JSONArray message payload"]) + + # Get the history for this channel + history = publish0.history() + messages = history.items + assert messages is not None, "Expected non-None messages" + assert len(messages) == 4, "Expected 4 messages" + + message_contents = dict((m.name, m.data) for m in messages) + log.debug("message_contents: %s" % str(message_contents)) + + assert message_contents["publish0"] == "This is a string message payload", \ + "Expect publish0 to be expected String)" + + assert message_contents["publish1"] == b"This is a byte[] message payload", \ + "Expect publish1 to be expected byte[]. Actual: %s" % str(message_contents['publish1']) + + assert message_contents["publish2"] == {"test": "This is a JSONObject message payload"}, \ + "Expect publish2 to be expected JSONObject" + + assert message_contents["publish3"] == ["This is a JSONArray message payload"], \ + "Expect publish3 to be expected JSONObject" + + @dont_vary_protocol + def test_unsupported_payload_must_raise_exception(self): + channel = self.ably.channels["persisted:publish0"] + for data in [1, 1.1, True]: + with pytest.raises(AblyException): + channel.publish('event', data) + + def test_publish_message_list(self): + channel = self.ably.channels[ + self.get_channel_name('persisted:message_list_channel')] + + expected_messages = [Message("name-{}".format(i), str(i)) for i in range(3)] + + channel.publish(messages=expected_messages) + + # Get the history for this channel + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == len(expected_messages), "Expected 3 messages" + + for m, expected_m in zip(messages, reversed(expected_messages)): + assert m.name == expected_m.name + assert m.data == expected_m.data + + def test_message_list_generate_one_request(self): + channel = self.ably.channels[ + self.get_channel_name('persisted:message_list_channel_one_request')] + + expected_messages = [Message("name-{}".format(i), str(i)) for i in range(3)] + + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish(messages=expected_messages) + assert post_mock.call_count == 1 + + if self.use_binary_protocol: + messages = msgpack.unpackb(post_mock.call_args[1]['body']) + else: + messages = json.loads(post_mock.call_args[1]['body']) + + for i, message in enumerate(messages): + assert message['name'] == 'name-' + str(i) + assert message['data'] == str(i) + + def test_publish_error(self): + ably = TestApp.get_ably_rest(use_binary_protocol=self.use_binary_protocol) + ably.auth.authorize( + token_params={'capability': {"only_subscribe": ["subscribe"]}}) + + with pytest.raises(AblyException) as excinfo: + ably.channels["only_subscribe"].publish() + + assert 401 == excinfo.value.status_code + assert 40160 == excinfo.value.code + ably.close() + + def test_publish_message_null_name(self): + channel = self.ably.channels[ + self.get_channel_name('persisted:message_null_name_channel')] + + data = "String message" + channel.publish(name=None, data=data) + + # Get the history for this channel + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 1, "Expected 1 message" + assert messages[0].name is None + assert messages[0].data == data + + def test_publish_message_null_data(self): + channel = self.ably.channels[ + self.get_channel_name('persisted:message_null_data_channel')] + + name = "Test name" + channel.publish(name=name, data=None) + + # Get the history for this channel + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 1, "Expected 1 message" + + assert messages[0].name == name + assert messages[0].data is None + + def test_publish_message_null_name_and_data(self): + channel = self.ably.channels[ + self.get_channel_name('persisted:null_name_and_data_channel')] + + channel.publish(name=None, data=None) + channel.publish() + + # Get the history for this channel + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 2, "Expected 2 messages" + + for m in messages: + assert m.name is None + assert m.data is None + + def test_publish_message_null_name_and_data_keys_arent_sent(self): + channel = self.ably.channels[ + self.get_channel_name('persisted:null_name_and_data_keys_arent_sent_channel')] + + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish(name=None, data=None) + + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 1, "Expected 1 message" + + assert post_mock.call_count == 1 + + if self.use_binary_protocol: + posted_body = msgpack.unpackb(post_mock.call_args[1]['body']) + else: + posted_body = json.loads(post_mock.call_args[1]['body']) + + assert 'name' not in posted_body + assert 'data' not in posted_body + + def test_message_attr(self): + publish0 = self.ably.channels[ + self.get_channel_name('persisted:publish_message_attr')] + + messages = [Message('publish', + {"test": "This is a JSONObject message payload"}, + client_id='client_id')] + publish0.publish(messages=messages) + + # Get the history for this channel + history = publish0.history() + message = history.items[0] + assert isinstance(message, Message) + assert message.id + assert message.name + assert message.data == {'test': 'This is a JSONObject message payload'} + assert message.encoding == '' + assert message.client_id == 'client_id' + assert isinstance(message.timestamp, int) + + def test_token_is_bound_to_options_client_id_after_publish(self): + # null before publish + assert self.ably_with_client_id.auth.token_details is None + + # created after message publish and will have client_id + channel = self.ably_with_client_id.channels[ + self.get_channel_name('persisted:restricted_to_client_id')] + channel.publish(name='publish', data='test') + + # defined after publish + assert isinstance(self.ably_with_client_id.auth.token_details, TokenDetails) + assert self.ably_with_client_id.auth.token_details.client_id == self.client_id + assert self.ably_with_client_id.auth.auth_mechanism == Auth.Method.TOKEN + history = channel.history() + assert history.items[0].client_id == self.client_id + + def test_publish_message_without_client_id_on_identified_client(self): + channel = self.ably_with_client_id.channels[ + self.get_channel_name('persisted:no_client_id_identified_client')] + + with mock.patch('ably.rest.rest.Http.post', + wraps=channel.ably.http.post) as post_mock: + channel.publish(name='publish', data='test') + + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 1, "Expected 1 message" + + assert post_mock.call_count == 2 + + if self.use_binary_protocol: + posted_body = msgpack.unpackb( + post_mock.mock_calls[0][2]['body']) + else: + posted_body = json.loads( + post_mock.mock_calls[0][2]['body']) + + assert 'client_id' not in posted_body + + # Get the history for this channel + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 1, "Expected 1 message" + + assert messages[0].client_id == self.ably_with_client_id.client_id + + def test_publish_message_with_client_id_on_identified_client(self): + # works if same + channel = self.ably_with_client_id.channels[ + self.get_channel_name('persisted:with_client_id_identified_client')] + message = Message(name='publish', data='test', client_id=self.ably_with_client_id.client_id) + channel.publish(message) + + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 1, "Expected 1 message" + + assert messages[0].client_id == self.ably_with_client_id.client_id + + message = Message(name='publish', data='test', client_id='invalid') + # fails if different + with pytest.raises(IncompatibleClientIdException): + channel.publish(message) + + def test_publish_message_with_wrong_client_id_on_implicit_identified_client(self): + new_token = self.ably.auth.authorize(token_params={'client_id': uuid.uuid4().hex}) + new_ably = TestApp.get_ably_rest(key=None, + token=new_token.token, + use_binary_protocol=self.use_binary_protocol) + + channel = new_ably.channels[ + self.get_channel_name('persisted:wrong_client_id_implicit_client')] + + message = Message(name='publish', data='test', client_id='invalid') + with pytest.raises(AblyException) as excinfo: + channel.publish(message) + + assert 400 == excinfo.value.status_code + assert 40012 == excinfo.value.code + new_ably.close() + + # RSA15b + def test_wildcard_client_id_can_publish_as_others(self): + wildcard_token_details = self.ably.auth.request_token({'client_id': '*'}) + wildcard_ably = TestApp.get_ably_rest( + key=None, + token_details=wildcard_token_details, + use_binary_protocol=self.use_binary_protocol) + + assert wildcard_ably.auth.client_id == '*' + channel = wildcard_ably.channels[ + self.get_channel_name('persisted:wildcard_client_id')] + channel.publish(name='publish1', data='no client_id') + some_client_id = uuid.uuid4().hex + message = Message(name='publish2', data='some client_id', client_id=some_client_id) + channel.publish(message) + + history = channel.history() + messages = history.items + + assert messages is not None, "Expected non-None messages" + assert len(messages) == 2, "Expected 2 messages" + + assert messages[0].client_id == some_client_id + assert messages[1].client_id is None + + wildcard_ably.close() + + # TM2h + @dont_vary_protocol + def test_invalid_connection_key(self): + channel = self.ably.channels["persisted:invalid_connection_key"] + message = Message(data='payload', connection_key='should.be.wrong') + with pytest.raises(AblyException) as excinfo: + channel.publish(messages=[message]) + + assert 400 == excinfo.value.status_code + assert 40006 == excinfo.value.code + + # TM2i, RSL6a2, RSL1h + def test_publish_extras(self): + channel = self.ably.channels[ + self.get_channel_name('canpublish:extras_channel')] + extras = { + 'push': { + 'notification': {"title": "Testing"}, + } + } + message = Message(name='test-name', data='test-data', extras=extras) + channel.publish(message) + + # Get the history for this channel + history = channel.history() + message = history.items[0] + assert message.name == 'test-name' + assert message.data == 'test-data' + assert message.extras == extras + + # RSL6a1 + def test_interoperability(self): + name = self.get_channel_name('persisted:interoperability_channel') + channel = self.ably.channels[name] + + url = 'https://%s/channels/%s/messages' % (self.test_vars["host"], name) + key = self.test_vars['keys'][0] + auth = (key['key_name'], key['key_secret']) + + type_mapping = { + 'string': str, + 'jsonObject': dict, + 'jsonArray': list, + 'binary': bytearray, + } + + root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + path = os.path.join(root_dir, 'submodules', 'test-resources', 'messages-encoding.json') + with open(path) as f: + data = json.load(f) + for input_msg in data['messages']: + data = input_msg['data'] + encoding = input_msg['encoding'] + expected_type = input_msg['expectedType'] + if expected_type == 'binary': + expected_value = input_msg.get('expectedHexValue') + expected_value = expected_value.encode('ascii') + expected_value = binascii.a2b_hex(expected_value) + else: + expected_value = input_msg.get('expectedValue') + + # 1) + channel.publish(data=expected_value) + with httpx.Client(http2=True) as client: + r = client.get(url, auth=auth) + item = r.json()[0] + assert item.get('encoding') == encoding + if encoding == 'json': + assert json.loads(item['data']) == json.loads(data) + else: + assert item['data'] == data + + # 2) + channel.publish(messages=[Message(data=data, encoding=encoding)]) + history = channel.history() + message = history.items[0] + assert message.data == expected_value + assert type(message.data) == type_mapping[expected_type] + + # https://github.com/ably/ably-python/issues/130 + def test_publish_slash(self): + channel = self.ably.channels.get(self.get_channel_name('persisted:widgets/')) + name, data = 'Name', 'Data' + channel.publish(name, data) + history = channel.history() + assert len(history.items) == 1 + assert history.items[0].name == name + assert history.items[0].data == data + + # RSL1l + @dont_vary_protocol + def test_publish_params(self): + channel = self.ably.channels.get(self.get_channel_name()) + + message = Message('name', 'data') + with pytest.raises(AblyException) as excinfo: + channel.publish(message, {'_forceNack': True}) + + assert 400 == excinfo.value.status_code + assert 40099 == excinfo.value.code + + +class TestRestChannelPublishIdempotent(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + self.ably_idempotent = TestApp.get_ably_rest(idempotent_rest_publishing=True) + + def tearDown(self): + self.ably.close() + self.ably_idempotent.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + # TO3n + @dont_vary_protocol + def test_idempotent_rest_publishing(self): + # Test default value + if api_version < '1.2': + assert self.ably.options.idempotent_rest_publishing is False + else: + assert self.ably.options.idempotent_rest_publishing is True + + # Test setting value explicitly + ably = TestApp.get_ably_rest(idempotent_rest_publishing=True) + assert ably.options.idempotent_rest_publishing is True + ably.close() + + ably = TestApp.get_ably_rest(idempotent_rest_publishing=False) + assert ably.options.idempotent_rest_publishing is False + ably.close() + + # RSL1j + @dont_vary_protocol + def test_message_serialization(self): + channel = self.get_channel() + + data = { + 'name': 'name', + 'data': 'data', + 'client_id': 'client_id', + 'extras': {}, + 'id': 'foobar', + } + message = Message(**data) + request_body = channel._Channel__publish_request_body(messages=[message]) + input_keys = set(case.snake_to_camel(x) for x in data.keys()) + assert input_keys - set(request_body) == set() + + # RSL1k1 + @dont_vary_protocol + def test_idempotent_library_generated(self): + channel = self.ably_idempotent.channels[self.get_channel_name()] + + message = Message('name', 'data') + request_body = channel._Channel__publish_request_body(messages=[message]) + base_id, serial = request_body['id'].split(':') + assert len(base64.b64decode(base_id)) >= 9 + assert serial == '0' + + # RSL1k2 + @dont_vary_protocol + def test_idempotent_client_supplied(self): + channel = self.ably_idempotent.channels[self.get_channel_name()] + + message = Message('name', 'data', id='foobar') + request_body = channel._Channel__publish_request_body(messages=[message]) + assert request_body['id'] == 'foobar' + + # RSL1k3 + @dont_vary_protocol + def test_idempotent_mixed_ids(self): + channel = self.ably_idempotent.channels[self.get_channel_name()] + + messages = [ + Message('name', 'data', id='foobar'), + Message('name', 'data'), + ] + request_body = channel._Channel__publish_request_body(messages=messages) + assert request_body[0]['id'] == 'foobar' + assert 'id' not in request_body[1] + + def get_ably_rest(self, *args, **kwargs): + kwargs['use_binary_protocol'] = self.use_binary_protocol + return TestApp.get_ably_rest(*args, **kwargs) + + # RSL1k4 + def test_idempotent_library_generated_retry(self): + test_vars = TestApp.get_test_vars() + ably = self.get_ably_rest(idempotent_rest_publishing=True, fallback_hosts=[test_vars["host"]] * 3) + channel = ably.channels[self.get_channel_name()] + + state = {'failures': 0} + client = httpx.Client(http2=True) + send = client.send + + def side_effect(*args, **kwargs): + x = send(args[1]) + if state['failures'] < 2: + state['failures'] += 1 + raise Exception('faked exception') + return x + + messages = [Message('name1', 'data1')] + with mock.patch('httpx.AsyncClient.send', side_effect=side_effect, autospec=True): + channel.publish(messages=messages) + + assert state['failures'] == 2 + history = channel.history() + assert len(history.items) == 1 + client.close() + ably.close() + + # RSL1k5 + def test_idempotent_client_supplied_publish(self): + ably = self.get_ably_rest(idempotent_rest_publishing=True) + channel = ably.channels[self.get_channel_name()] + + messages = [Message('name1', 'data1', id='foobar')] + channel.publish(messages=messages) + channel.publish(messages=messages) + channel.publish(messages=messages) + history = channel.history() + assert len(history.items) == 1 + ably.close() diff --git a/test/ably/sync/rest/restchannels_test.py b/test/ably/sync/rest/restchannels_test.py new file mode 100644 index 00000000..43401d36 --- /dev/null +++ b/test/ably/sync/rest/restchannels_test.py @@ -0,0 +1,91 @@ +from collections.abc import Iterable + +import pytest + +from ably.sync import AblyException +from ably.sync.rest.channel import Channel, Channels, Presence +from ably.sync.util.crypto import generate_random_key + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import BaseAsyncTestCase + + +# makes no request, no need to use different protocols +class TestChannels(BaseAsyncTestCase): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + self.ably = TestApp.get_ably_rest() + + def tearDown(self): + self.ably.close() + + def test_rest_channels_attr(self): + assert hasattr(self.ably, 'channels') + assert isinstance(self.ably.channels, Channels) + + def test_channels_get_returns_new_or_existing(self): + channel = self.ably.channels.get('new_channel') + assert isinstance(channel, Channel) + channel_same = self.ably.channels.get('new_channel') + assert channel is channel_same + + def test_channels_get_returns_new_with_options(self): + key = generate_random_key() + channel = self.ably.channels.get('new_channel', cipher={'key': key}) + assert isinstance(channel, Channel) + assert channel.cipher.secret_key is key + + def test_channels_get_updates_existing_with_options(self): + key = generate_random_key() + channel = self.ably.channels.get('new_channel', cipher={'key': key}) + assert channel.cipher is not None + + channel_same = self.ably.channels.get('new_channel', cipher=None) + assert channel is channel_same + assert channel.cipher is None + + def test_channels_get_doesnt_updates_existing_with_none_options(self): + key = generate_random_key() + channel = self.ably.channels.get('new_channel', cipher={'key': key}) + assert channel.cipher is not None + + channel_same = self.ably.channels.get('new_channel') + assert channel is channel_same + assert channel.cipher is not None + + def test_channels_in(self): + assert 'new_channel' not in self.ably.channels + self.ably.channels.get('new_channel') + new_channel_2 = self.ably.channels.get('new_channel_2') + assert 'new_channel' in self.ably.channels + assert new_channel_2 in self.ably.channels + + def test_channels_iteration(self): + channel_names = ['channel_{}'.format(i) for i in range(5)] + [self.ably.channels.get(name) for name in channel_names] + + assert isinstance(self.ably.channels, Iterable) + for name, channel in zip(channel_names, self.ably.channels): + assert isinstance(channel, Channel) + assert name == channel.name + + # RSN4a, RSN4b + def test_channels_release(self): + self.ably.channels.get('new_channel') + self.ably.channels.release('new_channel') + self.ably.channels.release('new_channel') + + def test_channel_has_presence(self): + channel = self.ably.channels.get('new_channnel') + assert channel.presence + assert isinstance(channel.presence, Presence) + + def test_without_permissions(self): + key = self.test_vars["keys"][2] + ably = TestApp.get_ably_rest(key=key["key_str"]) + with pytest.raises(AblyException) as excinfo: + ably.channels['test_publish_without_permission'].publish('foo', 'woop') + + assert 'not permitted' in excinfo.value.message + ably.close() diff --git a/test/ably/sync/rest/restchannelstatus_test.py b/test/ably/sync/rest/restchannelstatus_test.py new file mode 100644 index 00000000..5d281221 --- /dev/null +++ b/test/ably/sync/rest/restchannelstatus_test.py @@ -0,0 +1,47 @@ +import logging + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, BaseAsyncTestCase + +log = logging.getLogger(__name__) + + +class TestRestChannelStatus(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + def test_channel_status(self): + channel_name = self.get_channel_name('test_channel_status') + channel = self.ably.channels[channel_name] + + channel_status = channel.status() + + assert channel_status is not None, "Expected non-None channel_status" + assert channel_name == channel_status.channel_id, "Expected channel name to match" + assert channel_status.status.is_active is True, "Expected is_active to be True" + assert isinstance(channel_status.status.occupancy.metrics.publishers, int) and\ + channel_status.status.occupancy.metrics.publishers >= 0,\ + "Expected publishers to be a non-negative int" + assert isinstance(channel_status.status.occupancy.metrics.connections, int) and\ + channel_status.status.occupancy.metrics.connections >= 0,\ + "Expected connections to be a non-negative int" + assert isinstance(channel_status.status.occupancy.metrics.subscribers, int) and\ + channel_status.status.occupancy.metrics.subscribers >= 0,\ + "Expected subscribers to be a non-negative int" + assert isinstance(channel_status.status.occupancy.metrics.presence_members, int) and\ + channel_status.status.occupancy.metrics.presence_members >= 0,\ + "Expected presence_members to be a non-negative int" + assert isinstance(channel_status.status.occupancy.metrics.presence_connections, int) and\ + channel_status.status.occupancy.metrics.presence_connections >= 0,\ + "Expected presence_connections to be a non-negative int" + assert isinstance(channel_status.status.occupancy.metrics.presence_subscribers, int) and\ + channel_status.status.occupancy.metrics.presence_subscribers >= 0,\ + "Expected presence_subscribers to be a non-negative int" diff --git a/test/ably/sync/rest/restcrypto_test.py b/test/ably/sync/rest/restcrypto_test.py new file mode 100644 index 00000000..3dd89bc2 --- /dev/null +++ b/test/ably/sync/rest/restcrypto_test.py @@ -0,0 +1,264 @@ +# import json +# import os +# import logging +# import base64 +# +# import pytest +# +# from ably import AblyException +# from ably.types.message import Message +# from ably.util.crypto import CipherParams, get_cipher, generate_random_key, get_default_params +# +# from Crypto import Random +# +# from test.ably.testapp import TestApp +# from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase +# +# log = logging.getLogger(__name__) +# +# +# class TestRestCrypto(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): +# +# async def asyncSetUp(self): +# self.test_vars = await TestApp.get_test_vars() +# self.ably = await TestApp.get_ably_rest() +# self.ably2 = await TestApp.get_ably_rest() +# +# async def asyncTearDown(self): +# await self.ably.close() +# await self.ably2.close() +# +# def per_protocol_setup(self, use_binary_protocol): +# # This will be called every test that vary by protocol for each protocol +# self.ably.options.use_binary_protocol = use_binary_protocol +# self.ably2.options.use_binary_protocol = use_binary_protocol +# self.use_binary_protocol = use_binary_protocol +# +# @dont_vary_protocol +# def test_cbc_channel_cipher(self): +# key = ( +# b'\x93\xe3\x5c\xc9\x77\x53\xfd\x1a' +# b'\x79\xb4\xd8\x84\xe7\xdc\xfd\xdf') +# +# iv = ( +# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' +# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0') +# +# log.debug("KEY_LEN: %d" % len(key)) +# log.debug("IV_LEN: %d" % len(iv)) +# cipher = get_cipher({'key': key, 'iv': iv}) +# +# plaintext = b"The quick brown fox" +# expected_ciphertext = ( +# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' +# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0' +# b'\x83\x5c\xcf\xce\x0c\xfd\xbe\x37' +# b'\xb7\x92\x12\x04\x1d\x45\x68\xa4' +# b'\xdf\x7f\x6e\x38\x17\x4a\xff\x50' +# b'\x73\x23\xbb\xca\x16\xb0\xe2\x84') +# +# actual_ciphertext = cipher.encrypt(plaintext) +# +# assert expected_ciphertext == actual_ciphertext +# +# async def test_crypto_publish(self): +# channel_name = self.get_channel_name('persisted:crypto_publish_text') +# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# history = await publish0.history() +# messages = history.items +# assert messages is not None, "Expected non-None messages" +# assert 4 == len(messages), "Expected 4 messages" +# +# message_contents = dict((m.name, m.data) for m in messages) +# log.debug("message_contents: %s" % str(message_contents)) +# +# assert "This is a string message payload" == message_contents["publish3"],\ +# "Expect publish3 to be expected String)" +# +# assert b"This is a byte[] message payload" == message_contents["publish4"],\ +# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) +# +# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ +# "Expect publish5 to be expected JSONObject" +# +# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ +# "Expect publish6 to be expected JSONObject" +# +# async def test_crypto_publish_256(self): +# rndfile = Random.new() +# key = rndfile.read(32) +# channel_name = 'persisted:crypto_publish_text_256' +# channel_name += '_bin' if self.use_binary_protocol else '_text' +# +# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# history = await publish0.history() +# messages = history.items +# assert messages is not None, "Expected non-None messages" +# assert 4 == len(messages), "Expected 4 messages" +# +# message_contents = dict((m.name, m.data) for m in messages) +# log.debug("message_contents: %s" % str(message_contents)) +# +# assert "This is a string message payload" == message_contents["publish3"],\ +# "Expect publish3 to be expected String)" +# +# assert b"This is a byte[] message payload" == message_contents["publish4"],\ +# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) +# +# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ +# "Expect publish5 to be expected JSONObject" +# +# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ +# "Expect publish6 to be expected JSONObject" +# +# async def test_crypto_publish_key_mismatch(self): +# channel_name = self.get_channel_name('persisted:crypto_publish_key_mismatch') +# +# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# with pytest.raises(AblyException) as excinfo: +# await rx_channel.history() +# +# message = excinfo.value.message +# assert 'invalid-padding' == message or "codec can't decode" in message +# +# async def test_crypto_send_unencrypted(self): +# channel_name = self.get_channel_name('persisted:crypto_send_unencrypted') +# publish0 = self.ably.channels[channel_name] +# +# await publish0.publish("publish3", "This is a string message payload") +# await publish0.publish("publish4", b"This is a byte[] message payload") +# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) +# await publish0.publish("publish6", ["This is a JSONArray message payload"]) +# +# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) +# +# history = await rx_channel.history() +# messages = history.items +# assert messages is not None, "Expected non-None messages" +# assert 4 == len(messages), "Expected 4 messages" +# +# message_contents = dict((m.name, m.data) for m in messages) +# log.debug("message_contents: %s" % str(message_contents)) +# +# assert "This is a string message payload" == message_contents["publish3"],\ +# "Expect publish3 to be expected String" +# +# assert b"This is a byte[] message payload" == message_contents["publish4"],\ +# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) +# +# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ +# "Expect publish5 to be expected JSONObject" +# +# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ +# "Expect publish6 to be expected JSONObject" +# +# async def test_crypto_encrypted_unhandled(self): +# channel_name = self.get_channel_name('persisted:crypto_send_encrypted_unhandled') +# key = b'0123456789abcdef' +# data = 'foobar' +# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) +# +# await publish0.publish("publish0", data) +# +# rx_channel = self.ably2.channels[channel_name] +# history = await rx_channel.history() +# message = history.items[0] +# cipher = get_cipher(get_default_params({'key': key})) +# assert cipher.decrypt(message.data).decode() == data +# assert message.encoding == 'utf-8/cipher+aes-128-cbc' +# +# @dont_vary_protocol +# def test_cipher_params(self): +# params = CipherParams(secret_key='0123456789abcdef') +# assert params.algorithm == 'AES' +# assert params.mode == 'CBC' +# assert params.key_length == 128 +# +# params = CipherParams(secret_key='0123456789abcdef' * 2) +# assert params.algorithm == 'AES' +# assert params.mode == 'CBC' +# assert params.key_length == 256 +# +# +# class AbstractTestCryptoWithFixture: +# +# @classmethod +# def setUpClass(cls): +# resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file +# with open(resources_path, 'r') as f: +# cls.fixture = json.loads(f.read()) +# cls.params = { +# 'secret_key': base64.b64decode(cls.fixture['key'].encode('ascii')), +# 'mode': cls.fixture['mode'], +# 'algorithm': cls.fixture['algorithm'], +# 'iv': base64.b64decode(cls.fixture['iv'].encode('ascii')), +# } +# cls.cipher_params = CipherParams(**cls.params) +# cls.cipher = get_cipher(cls.cipher_params) +# cls.items = cls.fixture['items'] +# +# def get_encoded(self, encoded_item): +# if encoded_item.get('encoding') == 'base64': +# return base64.b64decode(encoded_item['data'].encode('ascii')) +# elif encoded_item.get('encoding') == 'json': +# return json.loads(encoded_item['data']) +# return encoded_item['data'] +# +# # TM3 +# def test_decode(self): +# for item in self.items: +# assert item['encoded']['name'] == item['encrypted']['name'] +# message = Message.from_encoded(item['encrypted'], self.cipher) +# assert message.encoding == '' +# expected_data = self.get_encoded(item['encoded']) +# assert expected_data == message.data +# +# # TM3 +# def test_decode_array(self): +# items_encrypted = [item['encrypted'] for item in self.items] +# messages = Message.from_encoded_array(items_encrypted, self.cipher) +# for i, message in enumerate(messages): +# assert message.encoding == '' +# expected_data = self.get_encoded(self.items[i]['encoded']) +# assert expected_data == message.data +# +# def test_encode(self): +# for item in self.items: +# # need to reset iv +# self.cipher_params = CipherParams(**self.params) +# self.cipher = get_cipher(self.cipher_params) +# data = self.get_encoded(item['encoded']) +# expected = item['encrypted'] +# message = Message(item['encoded']['name'], data) +# message.encrypt(self.cipher) +# as_dict = message.as_dict() +# assert as_dict['data'] == expected['data'] +# assert as_dict['encoding'] == expected['encoding'] +# +# +# class TestCryptoWithFixture128(AbstractTestCryptoWithFixture, BaseTestCase): +# fixture_file = 'crypto-data-128.json' +# +# +# class TestCryptoWithFixture256(AbstractTestCryptoWithFixture, BaseTestCase): +# fixture_file = 'crypto-data-256.json' diff --git a/test/ably/sync/rest/resthttp_test.py b/test/ably/sync/rest/resthttp_test.py new file mode 100644 index 00000000..8b8fe771 --- /dev/null +++ b/test/ably/sync/rest/resthttp_test.py @@ -0,0 +1,229 @@ +import base64 +import re +import time + +import httpx +import mock +import pytest +from urllib.parse import urljoin + +import respx +from httpx import Response + +from ably.sync import AblyRest +from ably.sync.transport.defaults import Defaults +from ably.sync.types.options import Options +from ably.sync.util.exceptions import AblyException +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import BaseAsyncTestCase + + +class TestRestHttp(BaseAsyncTestCase): + def test_max_retry_attempts_and_timeouts_defaults(self): + ably = AblyRest(token="foo") + assert 'http_open_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS + assert 'http_request_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS + + with mock.patch('httpx.AsyncClient.send', side_effect=httpx.RequestError('')) as send_mock: + with pytest.raises(httpx.RequestError): + ably.http.make_request('GET', '/', version=Defaults.protocol_version, skip_auth=True) + + assert send_mock.call_count == Defaults.http_max_retry_count + assert send_mock.call_args == mock.call(mock.ANY) + ably.close() + + def test_cumulative_timeout(self): + ably = AblyRest(token="foo") + assert 'http_max_retry_duration' in ably.http.CONNECTION_RETRY_DEFAULTS + + ably.options.http_max_retry_duration = 0.5 + + def sleep_and_raise(*args, **kwargs): + time.sleep(0.51) + raise httpx.TimeoutException('timeout') + + with mock.patch('httpx.AsyncClient.send', side_effect=sleep_and_raise) as send_mock: + with pytest.raises(httpx.TimeoutException): + ably.http.make_request('GET', '/', skip_auth=True) + + assert send_mock.call_count == 1 + ably.close() + + def test_host_fallback(self): + ably = AblyRest(token="foo") + + def make_url(host): + base_url = "%s://%s:%d" % (ably.http.preferred_scheme, + host, + ably.http.preferred_port) + return urljoin(base_url, '/') + + with mock.patch('httpx.Request', wraps=httpx.Request) as request_mock: + with mock.patch('httpx.AsyncClient.send', side_effect=httpx.RequestError('')) as send_mock: + with pytest.raises(httpx.RequestError): + ably.http.make_request('GET', '/', skip_auth=True) + + assert send_mock.call_count == Defaults.http_max_retry_count + + expected_urls_set = { + make_url(host) + for host in Options(http_max_retry_count=10).get_rest_hosts() + } + for ((_, url), _) in request_mock.call_args_list: + assert url in expected_urls_set + expected_urls_set.remove(url) + + expected_hosts_set = set(Options(http_max_retry_count=10).get_rest_hosts()) + for (prep_request_tuple, _) in send_mock.call_args_list: + assert prep_request_tuple[0].headers.get('host') in expected_hosts_set + expected_hosts_set.remove(prep_request_tuple[0].headers.get('host')) + ably.close() + + @respx.mock + def test_no_host_fallback_nor_retries_if_custom_host(self): + custom_host = 'example.org' + ably = AblyRest(token="foo", rest_host=custom_host) + + mock_route = respx.get("https://example.org").mock(side_effect=httpx.RequestError('')) + + with pytest.raises(httpx.RequestError): + ably.http.make_request('GET', '/', skip_auth=True) + + assert mock_route.call_count == 1 + assert respx.calls.call_count == 1 + + ably.close() + + # RSC15f + def test_cached_fallback(self): + timeout = 2000 + ably = TestApp.get_ably_rest(fallback_retry_timeout=timeout) + host = ably.options.get_rest_host() + + state = {'errors': 0} + client = httpx.Client(http2=True) + send = client.send + + def side_effect(*args, **kwargs): + if args[1].url.host == host: + state['errors'] += 1 + raise RuntimeError + return send(args[1]) + + with mock.patch('httpx.AsyncClient.send', side_effect=side_effect, autospec=True): + # The main host is called and there's an error + ably.time() + assert state['errors'] == 1 + + # The cached host is used: no error + ably.time() + ably.time() + ably.time() + assert state['errors'] == 1 + + # The cached host has expired, we've an error again + time.sleep(timeout / 1000.0) + ably.time() + assert state['errors'] == 2 + + client.close() + ably.close() + + @respx.mock + def test_no_retry_if_not_500_to_599_http_code(self): + default_host = Options().get_rest_host() + ably = AblyRest(token="foo") + + default_url = "%s://%s:%d/" % ( + ably.http.preferred_scheme, + default_host, + ably.http.preferred_port) + + mock_response = httpx.Response(600, json={'message': "", 'status_code': 600, 'code': 50500}) + + mock_route = respx.get(default_url).mock(return_value=mock_response) + + with pytest.raises(AblyException): + ably.http.make_request('GET', '/', skip_auth=True) + + assert mock_route.call_count == 1 + assert respx.calls.call_count == 1 + + ably.close() + + def test_500_errors(self): + """ + Raise error if all the servers reply with a 5xx error. + https://github.com/ably/ably-python/issues/160 + """ + + ably = AblyRest(token="foo") + + def raise_ably_exception(*args, **kwargs): + raise AblyException(message="", status_code=500, code=50000) + + with mock.patch('httpx.Request', wraps=httpx.Request): + with mock.patch('ably.util.exceptions.AblyException.raise_for_response', + side_effect=raise_ably_exception) as send_mock: + with pytest.raises(AblyException): + ably.http.make_request('GET', '/', skip_auth=True) + + assert send_mock.call_count == 3 + ably.close() + + def test_custom_http_timeouts(self): + ably = AblyRest( + token="foo", http_request_timeout=30, http_open_timeout=8, + http_max_retry_count=6, http_max_retry_duration=20) + + assert ably.http.http_request_timeout == 30 + assert ably.http.http_open_timeout == 8 + assert ably.http.http_max_retry_count == 6 + assert ably.http.http_max_retry_duration == 20 + + # RSC7a, RSC7b + def test_request_headers(self): + ably = TestApp.get_ably_rest() + r = ably.http.make_request('HEAD', '/time', skip_auth=True) + + # API + assert 'X-Ably-Version' in r.request.headers + assert r.request.headers['X-Ably-Version'] == '3' + + # Agent + assert 'Ably-Agent' in r.request.headers + expr = r"^ably-python\/\d.\d.\d(-beta\.\d)? python\/\d.\d+.\d+$" + assert re.search(expr, r.request.headers['Ably-Agent']) + ably.close() + + # RSC7c + def test_add_request_ids(self): + # With request id + ably = TestApp.get_ably_rest(add_request_ids=True) + r = ably.http.make_request('HEAD', '/time', skip_auth=True) + assert 'request_id' in r.request.url.params + request_id1 = r.request.url.params['request_id'] + assert len(base64.urlsafe_b64decode(request_id1)) == 12 + + # With request id and new request + r = ably.http.make_request('HEAD', '/time', skip_auth=True) + assert 'request_id' in r.request.url.params + request_id2 = r.request.url.params['request_id'] + assert len(base64.urlsafe_b64decode(request_id2)) == 12 + assert request_id1 != request_id2 + ably.close() + + # With request id and new request + ably = TestApp.get_ably_rest() + r = ably.http.make_request('HEAD', '/time', skip_auth=True) + assert 'request_id' not in r.request.url.params + ably.close() + + def test_request_over_http2(self): + url = 'https://www.example.com' + respx.get(url).mock(return_value=Response(status_code=200)) + + ably = TestApp.get_ably_rest(rest_host=url) + r = ably.http.make_request('GET', url, skip_auth=True) + assert r.http_version == 'HTTP/2' + ably.close() diff --git a/test/ably/sync/rest/restinit_test.py b/test/ably/sync/rest/restinit_test.py new file mode 100644 index 00000000..84743360 --- /dev/null +++ b/test/ably/sync/rest/restinit_test.py @@ -0,0 +1,227 @@ +from mock import patch +import pytest +from httpx import Client + +from ably.sync import AblyRest +from ably.sync import AblyException +from ably.sync.transport.defaults import Defaults +from ably.sync.types.tokendetails import TokenDetails + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + + +class TestRestInit(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + + @dont_vary_protocol + def test_key_only(self): + ably = AblyRest(key=self.test_vars["keys"][0]["key_str"]) + assert ably.options.key_name == self.test_vars["keys"][0]["key_name"], "Key name does not match" + assert ably.options.key_secret == self.test_vars["keys"][0]["key_secret"], "Key secret does not match" + + def per_protocol_setup(self, use_binary_protocol): + self.use_binary_protocol = use_binary_protocol + + @dont_vary_protocol + def test_with_token(self): + ably = AblyRest(token="foo") + assert ably.options.auth_token == "foo", "Token not set at options" + + @dont_vary_protocol + def test_with_token_details(self): + td = TokenDetails() + ably = AblyRest(token_details=td) + assert ably.options.token_details is td + + @dont_vary_protocol + def test_with_options_token_callback(self): + def token_callback(**params): + return "this_is_not_really_a_token_request" + AblyRest(auth_callback=token_callback) + + @dont_vary_protocol + def test_ambiguous_key_raises_value_error(self): + with pytest.raises(ValueError, match="mutually exclusive"): + AblyRest(key=self.test_vars["keys"][0]["key_str"], key_name='x') + with pytest.raises(ValueError, match="mutually exclusive"): + AblyRest(key=self.test_vars["keys"][0]["key_str"], key_secret='x') + + @dont_vary_protocol + def test_with_key_name_or_secret_only(self): + with pytest.raises(ValueError, match="key is missing"): + AblyRest(key_name='x') + with pytest.raises(ValueError, match="key is missing"): + AblyRest(key_secret='x') + + @dont_vary_protocol + def test_with_key_name_and_secret(self): + ably = AblyRest(key_name="foo", key_secret="bar") + assert ably.options.key_name == "foo", "Key name does not match" + assert ably.options.key_secret == "bar", "Key secret does not match" + + @dont_vary_protocol + def test_with_options_auth_url(self): + AblyRest(auth_url='not_really_an_url') + + # RSC11 + @dont_vary_protocol + def test_rest_host_and_environment(self): + # rest host + ably = AblyRest(token='foo', rest_host="some.other.host") + assert "some.other.host" == ably.options.rest_host, "Unexpected host mismatch" + + # environment: production + ably = AblyRest(token='foo', environment="production") + host = ably.options.get_rest_host() + assert "rest.ably.io" == host, "Unexpected host mismatch %s" % host + + # environment: other + ably = AblyRest(token='foo', environment="sandbox") + host = ably.options.get_rest_host() + assert "sandbox-rest.ably.io" == host, "Unexpected host mismatch %s" % host + + # both, as per #TO3k2 + with pytest.raises(ValueError): + ably = AblyRest(token='foo', rest_host="some.other.host", + environment="some.other.environment") + + # RSC15 + @dont_vary_protocol + def test_fallback_hosts(self): + # Specify the fallback_hosts (RSC15a) + fallback_hosts = [ + ['fallback1.com', 'fallback2.com'], + [], + ] + + # Fallback hosts specified (RSC15g1) + for aux in fallback_hosts: + ably = AblyRest(token='foo', fallback_hosts=aux) + assert sorted(aux) == sorted(ably.options.get_fallback_rest_hosts()) + + # Specify environment (RSC15g2) + ably = AblyRest(token='foo', environment='sandbox', http_max_retry_count=10) + assert sorted(Defaults.get_environment_fallback_hosts('sandbox')) == sorted( + ably.options.get_fallback_rest_hosts()) + + # Fallback hosts and environment not specified (RSC15g3) + ably = AblyRest(token='foo', http_max_retry_count=10) + assert sorted(Defaults.fallback_hosts) == sorted(ably.options.get_fallback_rest_hosts()) + + # RSC15f + ably = AblyRest(token='foo') + assert 600000 == ably.options.fallback_retry_timeout + ably = AblyRest(token='foo', fallback_retry_timeout=1000) + assert 1000 == ably.options.fallback_retry_timeout + + @dont_vary_protocol + def test_specified_realtime_host(self): + ably = AblyRest(token='foo', realtime_host="some.other.host") + assert "some.other.host" == ably.options.realtime_host, "Unexpected host mismatch" + + @dont_vary_protocol + def test_specified_port(self): + ably = AblyRest(token='foo', port=9998, tls_port=9999) + assert 9999 == Defaults.get_port(ably.options),\ + "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port + + @dont_vary_protocol + def test_specified_non_tls_port(self): + ably = AblyRest(token='foo', port=9998, tls=False) + assert 9998 == Defaults.get_port(ably.options),\ + "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port + + @dont_vary_protocol + def test_specified_tls_port(self): + ably = AblyRest(token='foo', tls_port=9999, tls=True) + assert 9999 == Defaults.get_port(ably.options),\ + "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port + + @dont_vary_protocol + def test_tls_defaults_to_true(self): + ably = AblyRest(token='foo') + assert ably.options.tls, "Expected encryption to default to true" + assert Defaults.tls_port == Defaults.get_port(ably.options), "Unexpected port mismatch" + + @dont_vary_protocol + def test_tls_can_be_disabled(self): + ably = AblyRest(token='foo', tls=False) + assert not ably.options.tls, "Expected encryption to be False" + assert Defaults.port == Defaults.get_port(ably.options), "Unexpected port mismatch" + + @dont_vary_protocol + def test_with_no_params(self): + with pytest.raises(ValueError): + AblyRest() + + @dont_vary_protocol + def test_with_no_auth_params(self): + with pytest.raises(ValueError): + AblyRest(port=111) + + # RSA10k + def test_query_time_param(self): + ably = TestApp.get_ably_rest(query_time=True, + use_binary_protocol=self.use_binary_protocol) + + timestamp = ably.auth._timestamp + with patch('ably.rest.rest.AblyRest.time', wraps=ably.time) as server_time,\ + patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + ably.auth.request_token() + assert local_time.call_count == 1 + assert server_time.call_count == 1 + ably.auth.request_token() + assert local_time.call_count == 2 + assert server_time.call_count == 1 + + ably.close() + + @dont_vary_protocol + def test_requests_over_https_production(self): + ably = AblyRest(token='token') + assert 'https://rest.ably.io' == '{0}://{1}'.format(ably.http.preferred_scheme, ably.http.preferred_host) + assert ably.http.preferred_port == 443 + + @dont_vary_protocol + def test_requests_over_http_production(self): + ably = AblyRest(token='token', tls=False) + assert 'http://rest.ably.io' == '{0}://{1}'.format(ably.http.preferred_scheme, ably.http.preferred_host) + assert ably.http.preferred_port == 80 + + @dont_vary_protocol + def test_request_basic_auth_over_http_fails(self): + ably = AblyRest(key_secret='foo', key_name='bar', tls=False) + + with pytest.raises(AblyException) as excinfo: + ably.http.get('/time', skip_auth=False) + + assert 401 == excinfo.value.status_code + assert 40103 == excinfo.value.code + assert 'Cannot use Basic Auth over non-TLS connections' == excinfo.value.message + + @dont_vary_protocol + def test_environment(self): + ably = AblyRest(token='token', environment='custom') + with patch.object(Client, 'send', wraps=ably.http._Http__client.send) as get_mock: + try: + ably.time() + except AblyException: + pass + request = get_mock.call_args_list[0][0][0] + assert request.url == 'https://custom-rest.ably.io:443/time' + + ably.close() + + @dont_vary_protocol + def test_accepts_custom_http_timeouts(self): + ably = AblyRest( + token="foo", http_request_timeout=30, http_open_timeout=8, + http_max_retry_count=6, http_max_retry_duration=20) + + assert ably.options.http_request_timeout == 30 + assert ably.options.http_open_timeout == 8 + assert ably.options.http_max_retry_count == 6 + assert ably.options.http_max_retry_duration == 20 diff --git a/test/ably/sync/rest/restpaginatedresult_test.py b/test/ably/sync/rest/restpaginatedresult_test.py new file mode 100644 index 00000000..348e6b47 --- /dev/null +++ b/test/ably/sync/rest/restpaginatedresult_test.py @@ -0,0 +1,91 @@ +import respx +from httpx import Response + +from ably.sync.http.paginatedresult import PaginatedResult + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import BaseAsyncTestCase + + +class TestPaginatedResult(BaseAsyncTestCase): + + def get_response_callback(self, headers, body, status): + def callback(request): + res = request.url.params.get('page') + if res: + return Response( + status_code=status, + headers=headers, + content='[{"page": %i}]' % int(res) + ) + + return Response( + status_code=status, + headers=headers, + content=body + ) + + return callback + + def setUp(self): + self.ably = TestApp.get_ably_rest(use_binary_protocol=False) + # Mocked responses + # without specific headers + self.mocked_api = respx.mock(base_url='http://rest.ably.io') + self.ch1_route = self.mocked_api.get('/channels/channel_name/ch1') + self.ch1_route.return_value = Response( + headers={'content-type': 'application/json'}, + status_code=200, + content='[{"id": 0}, {"id": 1}]', + ) + # with headers + self.ch2_route = self.mocked_api.get('/channels/channel_name/ch2') + self.ch2_route.side_effect = self.get_response_callback( + headers={ + 'content-type': 'application/json', + 'link': + '; rel="first",' + ' ; rel="next"' + }, + body='[{"id": 0}, {"id": 1}]', + status=200 + ) + # start intercepting requests + self.mocked_api.start() + + self.paginated_result = PaginatedResult.paginated_query( + self.ably.http, + url='http://rest.ably.io/channels/channel_name/ch1', + response_processor=lambda response: response.to_native()) + self.paginated_result_with_headers = PaginatedResult.paginated_query( + self.ably.http, + url='http://rest.ably.io/channels/channel_name/ch2', + response_processor=lambda response: response.to_native()) + + def tearDown(self): + self.mocked_api.stop() + self.mocked_api.reset() + self.ably.close() + + def test_items(self): + assert len(self.paginated_result.items) == 2 + + def test_with_no_headers(self): + assert self.paginated_result.first() is None + assert self.paginated_result.next() is None + assert self.paginated_result.is_last() + + def test_with_next(self): + pag = self.paginated_result_with_headers + assert pag.has_next() + assert not pag.is_last() + + def test_first(self): + pag = self.paginated_result_with_headers + pag = pag.first() + assert pag.items[0]['page'] == 1 + + def test_next(self): + pag = self.paginated_result_with_headers + pag = pag.next() + assert pag.items[0]['page'] == 2 diff --git a/test/ably/sync/rest/restpresence_test.py b/test/ably/sync/rest/restpresence_test.py new file mode 100644 index 00000000..d3c81ab1 --- /dev/null +++ b/test/ably/sync/rest/restpresence_test.py @@ -0,0 +1,213 @@ +from datetime import datetime, timedelta + +import pytest +import respx + +from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.types.presence import PresenceMessage + +from test.ably.sync.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseAsyncTestCase +from test.ably.sync.testapp import TestApp + + +class TestPresence(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.test_vars = TestApp.get_test_vars() + self.ably = TestApp.get_ably_rest() + self.channel = self.ably.channels.get('persisted:presence_fixtures') + self.ably.options.use_binary_protocol = True + + def tearDown(self): + self.ably.channels.release('persisted:presence_fixtures') + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + + def test_channel_presence_get(self): + presence_page = self.channel.presence.get() + assert isinstance(presence_page, PaginatedResult) + assert len(presence_page.items) == 6 + member = presence_page.items[0] + assert isinstance(member, PresenceMessage) + assert member.action + assert member.id + assert member.client_id + assert member.data + assert member.connection_id + assert member.timestamp + + def test_channel_presence_history(self): + presence_history = self.channel.presence.history() + assert isinstance(presence_history, PaginatedResult) + assert len(presence_history.items) == 6 + member = presence_history.items[0] + assert isinstance(member, PresenceMessage) + assert member.action + assert member.id + assert member.client_id + assert member.data + assert member.connection_id + assert member.timestamp + assert member.encoding + + def test_presence_get_encoded(self): + presence_history = self.channel.presence.history() + assert presence_history.items[-1].data == "true" + assert presence_history.items[-2].data == "24" + assert presence_history.items[-3].data == "This is a string clientData payload" + # this one doesn't have encoding field + assert presence_history.items[-4].data == '{ "test": "This is a JSONObject clientData payload"}' + assert presence_history.items[-5].data == {"example": {"json": "Object"}} + + def test_timestamp_is_datetime(self): + presence_page = self.channel.presence.get() + member = presence_page.items[0] + assert isinstance(member.timestamp, datetime) + + def test_presence_message_has_correct_member_key(self): + presence_page = self.channel.presence.get() + member = presence_page.items[0] + + assert member.member_key == "%s:%s" % (member.connection_id, member.client_id) + + def presence_mock_url(self): + kwargs = { + 'scheme': 'https' if self.test_vars['tls'] else 'http', + 'host': self.test_vars['host'] + } + port = self.test_vars['tls_port'] if self.test_vars.get('tls') else kwargs['port'] + if port == 80: + kwargs['port_sufix'] = '' + else: + kwargs['port_sufix'] = ':' + str(port) + url = '{scheme}://{host}{port_sufix}/channels/persisted%3Apresence_fixtures/presence' + return url.format(**kwargs) + + def history_mock_url(self): + kwargs = { + 'scheme': 'https' if self.test_vars['tls'] else 'http', + 'host': self.test_vars['host'] + } + port = self.test_vars['tls_port'] if self.test_vars.get('tls') else kwargs['port'] + if port == 80: + kwargs['port_sufix'] = '' + else: + kwargs['port_sufix'] = ':' + str(port) + url = '{scheme}://{host}{port_sufix}/channels/persisted%3Apresence_fixtures/presence/history' + return url.format(**kwargs) + + @dont_vary_protocol + @respx.mock + def test_get_presence_default_limit(self): + url = self.presence_mock_url() + self.respx_add_empty_msg_pack(url) + self.channel.presence.get() + assert 'limit' not in respx.calls[0].request.url.params.keys() + + @dont_vary_protocol + @respx.mock + def test_get_presence_with_limit(self): + url = self.presence_mock_url() + self.respx_add_empty_msg_pack(url) + self.channel.presence.get(300) + assert '300' == respx.calls[0].request.url.params.get('limit') + + @dont_vary_protocol + @respx.mock + def test_get_presence_max_limit_is_1000(self): + url = self.presence_mock_url() + self.respx_add_empty_msg_pack(url) + with pytest.raises(ValueError): + self.channel.presence.get(5000) + + @dont_vary_protocol + @respx.mock + def test_history_default_limit(self): + url = self.history_mock_url() + self.respx_add_empty_msg_pack(url) + self.channel.presence.history() + assert 'limit' not in respx.calls[0].request.url.params.keys() + + @dont_vary_protocol + @respx.mock + def test_history_with_limit(self): + url = self.history_mock_url() + self.respx_add_empty_msg_pack(url) + self.channel.presence.history(300) + assert '300' == respx.calls[0].request.url.params.get('limit') + + @dont_vary_protocol + @respx.mock + def test_history_with_direction(self): + url = self.history_mock_url() + self.respx_add_empty_msg_pack(url) + self.channel.presence.history(direction='backwards') + assert 'backwards' == respx.calls[0].request.url.params.get('direction') + + @dont_vary_protocol + @respx.mock + def test_history_max_limit_is_1000(self): + url = self.history_mock_url() + self.respx_add_empty_msg_pack(url) + with pytest.raises(ValueError): + self.channel.presence.history(5000) + + @dont_vary_protocol + @respx.mock + def test_with_milisecond_start_end(self): + url = self.history_mock_url() + self.respx_add_empty_msg_pack(url) + self.channel.presence.history(start=100000, end=100001) + assert '100000' == respx.calls[0].request.url.params.get('start') + assert '100001' == respx.calls[0].request.url.params.get('end') + + @dont_vary_protocol + @respx.mock + def test_with_timedate_startend(self): + url = self.history_mock_url() + start = datetime(2015, 8, 15, 17, 11, 44, 706539) + start_ms = 1439658704706 + end = start + timedelta(hours=1) + end_ms = start_ms + (1000 * 60 * 60) + self.respx_add_empty_msg_pack(url) + self.channel.presence.history(start=start, end=end) + assert str(start_ms) in respx.calls[0].request.url.params.get('start') + assert str(end_ms) in respx.calls[0].request.url.params.get('end') + + @dont_vary_protocol + @respx.mock + def test_with_start_gt_end(self): + url = self.history_mock_url() + end = datetime(2015, 8, 15, 17, 11, 44, 706539) + start = end + timedelta(hours=1) + self.respx_add_empty_msg_pack(url) + with pytest.raises(ValueError, match="'end' parameter has to be greater than or equal to 'start'"): + self.channel.presence.history(start=start, end=end) + + +class TestPresenceCrypt(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + key = b'0123456789abcdef' + self.channel = self.ably.channels.get('persisted:presence_fixtures', cipher={'key': key}) + + def tearDown(self): + self.ably.channels.release('persisted:presence_fixtures') + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + + def test_presence_history_encrypted(self): + presence_history = self.channel.presence.history() + assert presence_history.items[0].data == {'foo': 'bar'} + + def test_presence_get_encrypted(self): + messages = self.channel.presence.get() + messages = (msg for msg in messages.items if msg.client_id == 'client_encoded') + message = next(messages) + + assert message.data == {'foo': 'bar'} diff --git a/test/ably/sync/rest/restpush_test.py b/test/ably/sync/rest/restpush_test.py new file mode 100644 index 00000000..c1127d2e --- /dev/null +++ b/test/ably/sync/rest/restpush_test.py @@ -0,0 +1,398 @@ +import itertools +import random +import string +import time + +import pytest + +from ably.sync import AblyException, AblyAuthException +from ably.sync import DeviceDetails, PushChannelSubscription +from ably.sync.http.paginatedresult import PaginatedResult + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, BaseAsyncTestCase +from test.ably.sync.utils import new_dict, random_string, get_random_key + + +DEVICE_TOKEN = '740f4707bebcf74f9b7c25d48e3358945f6aa01da5ddb387462c7eaf61bb78ad' + + +class TestPush(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + + # Register several devices for later use + self.devices = {} + for i in range(10): + self.save_device() + + # Register several subscriptions for later use + self.channels = {'canpublish:test1': [], 'canpublish:test2': [], 'canpublish:test3': []} + for key, channel in zip(self.devices, itertools.cycle(self.channels)): + device = self.devices[key] + self.save_subscription(channel, device_id=device.id) + assert len(list(itertools.chain(*self.channels.values()))) == len(self.devices) + + def tearDown(self): + for key, channel in zip(self.devices, itertools.cycle(self.channels)): + device = self.devices[key] + self.remove_subscription(channel, device_id=device.id) + self.ably.push.admin.device_registrations.remove(device_id=device.id) + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + + def get_client_id(self): + return random_string(12) + + def get_device_id(self): + return random_string(26, string.ascii_uppercase + string.digits) + + def gen_device_data(self, data=None, **kw): + if data is None: + data = { + 'id': self.get_device_id(), + 'clientId': self.get_client_id(), + 'platform': random.choice(['android', 'ios']), + 'formFactor': 'phone', + 'deviceSecret': 'test-secret', + 'push': { + 'recipient': { + 'transportType': 'apns', + 'deviceToken': DEVICE_TOKEN, + } + }, + } + else: + data = data.copy() + + data.update(kw) + return data + + def save_device(self, data=None, **kw): + """ + Helper method to register a device, to not have this code repeated + everywhere. Returns the input dict that was sent to Ably, and the + device details returned by Ably. + """ + data = self.gen_device_data(data, **kw) + device = self.ably.push.admin.device_registrations.save(data) + self.devices[device.id] = device + return device + + def remove_device(self, device_id): + result = self.ably.push.admin.device_registrations.remove(device_id) + self.devices.pop(device_id, None) + return result + + def remove_device_where(self, **kw): + remove_where = self.ably.push.admin.device_registrations.remove_where + result = remove_where(**kw) + + aux = {'deviceId': 'id', 'clientId': 'client_id'} + for device in list(self.devices.values()): + for key, value in kw.items(): + key = aux[key] + if getattr(device, key) == value: + del self.devices[device.id] + + return result + + def get_device(self): + key = get_random_key(self.devices) + return self.devices[key] + + def get_channel(self): + key = get_random_key(self.channels) + return key, self.channels[key] + + def save_subscription(self, channel, **kw): + """ + Helper method to register a device, to not have this code repeated + everywhere. Returns the input dict that was sent to Ably, and the + device details returned by Ably. + """ + subscription = PushChannelSubscription(channel, **kw) + subscription = self.ably.push.admin.channel_subscriptions.save(subscription) + self.channels.setdefault(channel, []).append(subscription) + return subscription + + def remove_subscription(self, channel, **kw): + subscription = PushChannelSubscription(channel, **kw) + subscription = self.ably.push.admin.channel_subscriptions.remove(subscription) + return subscription + + # RSH1a + def test_admin_publish(self): + recipient = {'clientId': 'ablyChannel'} + data = { + 'data': {'foo': 'bar'}, + } + + publish = self.ably.push.admin.publish + with pytest.raises(TypeError): + publish('ablyChannel', data) + with pytest.raises(TypeError): + publish(recipient, 25) + with pytest.raises(ValueError): + publish({}, data) + with pytest.raises(ValueError): + publish(recipient, {}) + + with pytest.raises(AblyException): + publish(recipient, {'xxx': 5}) + + assert publish(recipient, data) is None + + # RSH1b1 + def test_admin_device_registrations_get(self): + get = self.ably.push.admin.device_registrations.get + + # Not found + with pytest.raises(AblyException): + get('not-found') + + # Found + device = self.get_device() + device_details = get(device.id) + assert device_details.id == device.id + assert device_details.platform == device.platform + assert device_details.form_factor == device.form_factor + + # RSH1b2 + def test_admin_device_registrations_list(self): + list_devices = self.ably.push.admin.device_registrations.list + + list_response = list_devices() + assert type(list_response) is PaginatedResult + assert type(list_response.items) is list + assert type(list_response.items[0]) is DeviceDetails + + # limit + list_response = list_devices(limit=5000) + assert len(list_response.items) == len(self.devices) + list_response = list_devices(limit=2) + assert len(list_response.items) == 2 + + # Filter by device id + device = self.get_device() + list_response = list_devices(deviceId=device.id) + assert len(list_response.items) == 1 + list_response = list_devices(deviceId=self.get_device_id()) + assert len(list_response.items) == 0 + + # Filter by client id + list_response = list_devices(clientId=device.client_id) + assert len(list_response.items) == 1 + list_response = list_devices(clientId=self.get_client_id()) + assert len(list_response.items) == 0 + + # RSH1b3 + def test_admin_device_registrations_save(self): + # Create + data = self.gen_device_data() + device = self.save_device(data) + assert type(device) is DeviceDetails + + # Update + self.save_device(data, formFactor='tablet') + + # Invalid values + with pytest.raises(ValueError): + push = {'recipient': new_dict(data['push']['recipient'], transportType='xyz')} + self.save_device(data, push=push) + with pytest.raises(ValueError): + self.save_device(data, platform='native') + with pytest.raises(ValueError): + self.save_device(data, formFactor='fridge') + + # Fail + with pytest.raises(AblyException): + self.save_device(data, push={'color': 'red'}) + + # RSH1b4 + def test_admin_device_registrations_remove(self): + get = self.ably.push.admin.device_registrations.get + + device = self.get_device() + + # Remove + get_response = get(device.id) + assert get_response.id == device.id # Exists + remove_device_response = self.remove_device(device.id) + assert remove_device_response.status_code == 204 + with pytest.raises(AblyException): # Doesn't exist + get(device.id) + + # Remove again, it doesn't fail + remove_device_response = self.remove_device(device.id) + assert remove_device_response.status_code == 204 + + # RSH1b5 + def test_admin_device_registrations_remove_where(self): + get = self.ably.push.admin.device_registrations.get + + # Remove by device id + device = self.get_device() + foo_device = get(device.id) + assert foo_device.id == device.id # Exists + remove_foo_device_response = self.remove_device_where(deviceId=device.id) + assert remove_foo_device_response.status_code == 204 + with pytest.raises(AblyException): # Doesn't exist + get(device.id) + + # Remove by client id + device = self.get_device() + boo_device = get(device.id) + assert boo_device.id == device.id # Exists + remove_boo_device_response = self.remove_device_where(clientId=device.client_id) + assert remove_boo_device_response.status_code == 204 + # Doesn't exist (Deletion is async: wait up to a few seconds before giving up) + with pytest.raises(AblyException): + for i in range(5): + time.sleep(1) + get(device.id) + + # Remove with no matching params + remove_boo_device_response = self.remove_device_where(clientId=device.client_id) + assert remove_boo_device_response.status_code == 204 + + # # RSH1c1 + def test_admin_channel_subscriptions_list(self): + list_ = self.ably.push.admin.channel_subscriptions.list + + channel, subscriptions = self.get_channel() + + list_response = list_(channel=channel) + + assert type(list_response) is PaginatedResult + assert type(list_response.items) is list + assert type(list_response.items[0]) is PushChannelSubscription + + # limit + list_response = list_(channel=channel, limit=2) + assert len(list_response.items) == 2 + + list_response = list_(channel=channel, limit=5000) + assert len(list_response.items) == len(subscriptions) + + # Filter by device id + device_id = subscriptions[0].device_id + list_response = list_(channel=channel, deviceId=device_id) + assert len(list_response.items) == 1 + assert list_response.items[0].device_id == device_id + assert list_response.items[0].channel == channel + list_response = list_(channel=channel, deviceId=self.get_device_id()) + assert len(list_response.items) == 0 + + # Filter by client id + device = self.get_device() + list_response = list_(channel=channel, clientId=device.client_id) + assert len(list_response.items) == 0 + + # RSH1c2 + def test_admin_channels_list(self): + list_ = self.ably.push.admin.channel_subscriptions.list_channels + + list_response = list_() + assert type(list_response) is PaginatedResult + assert type(list_response.items) is list + assert type(list_response.items[0]) is str + + # limit + list_response = list_(limit=5000) + assert len(list_response.items) == len(self.channels) + list_response = list_(limit=1) + assert len(list_response.items) == 1 + + # RSH1c3 + def test_admin_channel_subscriptions_save(self): + save = self.ably.push.admin.channel_subscriptions.save + + # Subscribe + device = self.get_device() + channel = 'canpublish:testsave' + subscription = self.save_subscription(channel, device_id=device.id) + assert type(subscription) is PushChannelSubscription + assert subscription.channel == channel + assert subscription.device_id == device.id + assert subscription.client_id is None + + # Failures + client_id = self.get_client_id() + with pytest.raises(ValueError): + PushChannelSubscription(channel, device_id=device.id, client_id=client_id) + + subscription = PushChannelSubscription('notallowed', device_id=device.id) + with pytest.raises(AblyAuthException): + save(subscription) + + subscription = PushChannelSubscription(channel, device_id='notregistered') + with pytest.raises(AblyException): + save(subscription) + + # RSH1c4 + def test_admin_channel_subscriptions_remove(self): + save = self.ably.push.admin.channel_subscriptions.save + remove = self.ably.push.admin.channel_subscriptions.remove + list_ = self.ably.push.admin.channel_subscriptions.list + + channel = 'canpublish:testremove' + + # Subscribe device + device = self.get_device() + subscription = save(PushChannelSubscription(channel, device_id=device.id)) + list_response = list_(channel=channel) + assert device.id in (x.device_id for x in list_response.items) + remove_response = remove(subscription) + assert remove_response.status_code == 204 + list_response = list_(channel=channel) + assert device.id not in (x.device_id for x in list_response.items) + + # Subscribe client + client_id = self.get_client_id() + subscription = save(PushChannelSubscription(channel, client_id=client_id)) + list_response = list_(channel=channel) + assert client_id in (x.client_id for x in list_response.items) + remove_response = remove(subscription) + assert remove_response.status_code == 204 + list_response = list_(channel=channel) + assert client_id not in (x.client_id for x in list_response.items) + + # Remove again, it doesn't fail + remove_response = remove(subscription) + assert remove_response.status_code == 204 + + # RSH1c5 + def test_admin_channel_subscriptions_remove_where(self): + save = self.ably.push.admin.channel_subscriptions.save + remove = self.ably.push.admin.channel_subscriptions.remove_where + list_ = self.ably.push.admin.channel_subscriptions.list + + channel = 'canpublish:testremovewhere' + + # Subscribe device + device = self.get_device() + save(PushChannelSubscription(channel, device_id=device.id)) + list_response = list_(channel=channel) + assert device.id in (x.device_id for x in list_response.items) + remove_response = remove(channel=channel, device_id=device.id) + assert remove_response.status_code == 204 + list_response = list_(channel=channel) + assert device.id not in (x.device_id for x in list_response.items) + + # Subscribe client + client_id = self.get_client_id() + save(PushChannelSubscription(channel, client_id=client_id)) + list_response = list_(channel=channel) + assert client_id in (x.client_id for x in list_response.items) + remove_response = remove(channel=channel, client_id=client_id) + assert remove_response.status_code == 204 + list_response = list_(channel=channel) + assert client_id not in (x.client_id for x in list_response.items) + + # Remove again, it doesn't fail + remove_response = remove(channel=channel, client_id=client_id) + assert remove_response.status_code == 204 diff --git a/test/ably/sync/rest/restrequest_test.py b/test/ably/sync/rest/restrequest_test.py new file mode 100644 index 00000000..cad062c3 --- /dev/null +++ b/test/ably/sync/rest/restrequest_test.py @@ -0,0 +1,132 @@ +import httpx +import pytest +import respx + +from ably.sync import AblyRest +from ably.sync.http.paginatedresult import HttpPaginatedResponse +from ably.sync.transport.defaults import Defaults +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import BaseAsyncTestCase +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol + + +# RSC19 +class TestRestRequest(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + self.test_vars = TestApp.get_test_vars() + + # Populate the channel (using the new api) + self.channel = self.get_channel_name() + self.path = '/channels/%s/messages' % self.channel + for i in range(20): + body = {'name': 'event%s' % i, 'data': 'lorem ipsum %s' % i} + self.ably.request('POST', self.path, body=body, version=Defaults.protocol_version) + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + def test_post(self): + body = {'name': 'test-post', 'data': 'lorem ipsum'} + result = self.ably.request('POST', self.path, body=body, version=Defaults.protocol_version) + + assert isinstance(result, HttpPaginatedResponse) # RSC19d + # HP3 + assert type(result.items) is list + assert len(result.items) == 1 + assert result.items[0]['channel'] == self.channel + assert 'messageId' in result.items[0] + + def test_get(self): + params = {'limit': 10, 'direction': 'forwards'} + result = self.ably.request('GET', self.path, params=params, version=Defaults.protocol_version) + + assert isinstance(result, HttpPaginatedResponse) # RSC19d + + # HP2 + assert isinstance(result.next(), HttpPaginatedResponse) + assert isinstance(result.first(), HttpPaginatedResponse) + + # HP3 + assert isinstance(result.items, list) + item = result.items[0] + assert isinstance(item, dict) + assert 'timestamp' in item + assert 'id' in item + assert item['name'] == 'event0' + assert item['data'] == 'lorem ipsum 0' + + assert result.status_code == 200 # HP4 + assert result.success is True # HP5 + assert result.error_code is None # HP6 + assert result.error_message is None # HP7 + assert isinstance(result.headers, list) # HP7 + + @dont_vary_protocol + def test_not_found(self): + result = self.ably.request('GET', '/not-found', version=Defaults.protocol_version) + assert isinstance(result, HttpPaginatedResponse) # RSC19d + assert result.status_code == 404 # HP4 + assert result.success is False # HP5 + + @dont_vary_protocol + def test_error(self): + params = {'limit': 'abc'} + result = self.ably.request('GET', self.path, params=params, version=Defaults.protocol_version) + assert isinstance(result, HttpPaginatedResponse) # RSC19d + assert result.status_code == 400 # HP4 + assert not result.success + assert result.error_code + assert result.error_message + + def test_headers(self): + key = 'X-Test' + value = 'lorem ipsum' + result = self.ably.request('GET', '/time', headers={key: value}, version=Defaults.protocol_version) + assert result.response.request.headers[key] == value + + # RSC19e + @dont_vary_protocol + def test_timeout(self): + # Timeout + timeout = 0.000001 + ably = AblyRest(token="foo", http_request_timeout=timeout) + assert ably.http.http_request_timeout == timeout + with pytest.raises(httpx.ReadTimeout): + ably.request('GET', '/time', version=Defaults.protocol_version) + ably.close() + + default_endpoint = 'https://sandbox-rest.ably.io/time' + fallback_host = 'sandbox-a-fallback.ably-realtime.com' + fallback_endpoint = f'https://{fallback_host}/time' + ably = TestApp.get_ably_rest(fallback_hosts=[fallback_host]) + with respx.mock: + default_route = respx.get(default_endpoint) + fallback_route = respx.get(fallback_endpoint) + headers = { + "Content-Type": "application/json" + } + default_route.side_effect = httpx.ConnectError('') + fallback_route.return_value = httpx.Response(200, headers=headers, text='[123]') + ably.request('GET', '/time', version=Defaults.protocol_version) + ably.close() + + # Bad host, no Fallback + ably = AblyRest(key=self.test_vars["keys"][0]["key_str"], + rest_host='some.other.host', + port=self.test_vars["port"], + tls_port=self.test_vars["tls_port"], + tls=self.test_vars["tls"]) + with pytest.raises(httpx.ConnectError): + ably.request('GET', '/time', version=Defaults.protocol_version) + ably.close() + + def test_version(self): + version = "150" # chosen arbitrarily + result = self.ably.request('GET', '/time', "150") + assert result.response.request.headers["X-Ably-Version"] == version diff --git a/test/ably/sync/rest/reststats_test.py b/test/ably/sync/rest/reststats_test.py new file mode 100644 index 00000000..a621c927 --- /dev/null +++ b/test/ably/sync/rest/reststats_test.py @@ -0,0 +1,310 @@ +from datetime import datetime +from datetime import timedelta +import logging + +import pytest + +from ably.sync.types.stats import Stats +from ably.sync.util.exceptions import AblyException +from ably.sync.http.paginatedresult import PaginatedResult + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + +log = logging.getLogger(__name__) + + +class TestRestAppStatsSetup: + __stats_added = False + + def get_params(self): + return { + 'start': self.last_interval, + 'end': self.last_interval, + 'unit': 'minute', + 'limit': 1 + } + + def setUp(self): + self.ably = TestApp.get_ably_rest() + self.ably_text = TestApp.get_ably_rest(use_binary_protocol=False) + + self.last_year = datetime.now().year - 1 + self.previous_year = datetime.now().year - 2 + self.last_interval = datetime(self.last_year, 2, 3, 15, 5) + self.previous_interval = datetime(self.previous_year, 2, 3, 15, 5) + previous_year_stats = 120 + stats = [ + { + 'intervalId': Stats.to_interval_id(self.last_interval - timedelta(minutes=2), + 'minute'), + 'inbound': {'realtime': {'messages': {'count': 50, 'data': 5000}}}, + 'outbound': {'realtime': {'messages': {'count': 20, 'data': 2000}}} + }, + { + 'intervalId': Stats.to_interval_id(self.last_interval - timedelta(minutes=1), + 'minute'), + 'inbound': {'realtime': {'messages': {'count': 60, 'data': 6000}}}, + 'outbound': {'realtime': {'messages': {'count': 10, 'data': 1000}}} + }, + { + 'intervalId': Stats.to_interval_id(self.last_interval, 'minute'), + 'inbound': {'realtime': {'messages': {'count': 70, 'data': 7000}}}, + 'outbound': {'realtime': {'messages': {'count': 40, 'data': 4000}}}, + 'persisted': {'presence': {'count': 20, 'data': 2000}}, + 'connections': {'tls': {'peak': 20, 'opened': 10}}, + 'channels': {'peak': 50, 'opened': 30}, + 'apiRequests': {'succeeded': 50, 'failed': 10}, + 'tokenRequests': {'succeeded': 60, 'failed': 20}, + } + ] + + previous_stats = [] + for i in range(previous_year_stats): + previous_stats.append( + { + 'intervalId': Stats.to_interval_id(self.previous_interval - timedelta(minutes=i), + 'minute'), + 'inbound': {'realtime': {'messages': {'count': i}}} + } + ) + # asynctest does not support setUpClass method + if TestRestAppStatsSetup.__stats_added: + return + self.ably.http.post('/stats', body=stats + previous_stats) + TestRestAppStatsSetup.__stats_added = True + + def tearDown(self): + self.ably.close() + self.ably_text.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + + +class TestDirectionForwards(TestRestAppStatsSetup, BaseAsyncTestCase, + metaclass=VaryByProtocolTestsMetaclass): + + def get_params(self): + return { + 'start': self.last_interval - timedelta(minutes=2), + 'end': self.last_interval, + 'unit': 'minute', + 'direction': 'forwards', + 'limit': 1 + } + + def test_stats_are_forward(self): + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.inbound.realtime.all.count"] == 50 + + def test_three_pages(self): + stats_pages = self.ably.stats(**self.get_params()) + assert not stats_pages.is_last() + page2 = stats_pages.next() + page3 = page2.next() + assert page3.items[0].entries["messages.inbound.realtime.all.count"] == 70 + + +class TestDirectionBackwards(TestRestAppStatsSetup, BaseAsyncTestCase, + metaclass=VaryByProtocolTestsMetaclass): + + def get_params(self): + return { + 'end': self.last_interval, + 'unit': 'minute', + 'direction': 'backwards', + 'limit': 1 + } + + def test_stats_are_forward(self): + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.inbound.realtime.all.count"] == 70 + + def test_three_pages(self): + stats_pages = self.ably.stats(**self.get_params()) + assert not stats_pages.is_last() + page2 = stats_pages.next() + page3 = page2.next() + assert not stats_pages.is_last() + assert page3.items[0].entries["messages.inbound.realtime.all.count"] == 50 + + +class TestOnlyLastYear(TestRestAppStatsSetup, BaseAsyncTestCase, + metaclass=VaryByProtocolTestsMetaclass): + + def get_params(self): + return { + 'end': self.last_interval, + 'unit': 'minute', + 'limit': 3 + } + + def test_default_is_backwards(self): + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + assert stats[0].entries["messages.inbound.realtime.messages.count"] == 70 + assert stats[-1].entries["messages.inbound.realtime.messages.count"] == 50 + + +class TestPreviousYear(TestRestAppStatsSetup, BaseAsyncTestCase, + metaclass=VaryByProtocolTestsMetaclass): + + def get_params(self): + return { + 'end': self.previous_interval, + 'unit': 'minute', + } + + def test_default_100_pagination(self): + self.stats_pages = self.ably.stats(**self.get_params()) + stats = self.stats_pages.items + assert len(stats) == 100 + next_page = self.stats_pages.next() + assert len(next_page.items) == 20 + + +class TestRestAppStats(TestRestAppStatsSetup, BaseAsyncTestCase, + metaclass=VaryByProtocolTestsMetaclass): + + @dont_vary_protocol + def test_protocols(self): + stats_pages = self.ably.stats(**self.get_params()) + stats_pages1 = self.ably_text.stats(**self.get_params()) + assert len(stats_pages.items) == len(stats_pages1.items) + + def test_paginated_response(self): + stats_pages = self.ably.stats(**self.get_params()) + assert isinstance(stats_pages, PaginatedResult) + assert isinstance(stats_pages.items[0], Stats) + + def test_units(self): + for unit in ['hour', 'day', 'month']: + params = { + 'start': self.last_interval, + 'end': self.last_interval, + 'unit': unit, + 'direction': 'forwards', + 'limit': 1 + } + stats_pages = self.ably.stats(**params) + stat = stats_pages.items[0] + assert len(stats_pages.items) == 1 + assert stat.entries["messages.all.messages.count"] == 50 + 20 + 60 + 10 + 70 + 40 + assert stat.entries["messages.all.messages.data"] == 5000 + 2000 + 6000 + 1000 + 7000 + 4000 + + @dont_vary_protocol + def test_when_argument_start_is_after_end(self): + params = { + 'start': self.last_interval, + 'end': self.last_interval - timedelta(minutes=2), + 'unit': 'minute', + } + with pytest.raises(AblyException, match="'end' parameter has to be greater than or equal to 'start'"): + self.ably.stats(**params) + + @dont_vary_protocol + def test_when_limit_gt_1000(self): + params = { + 'end': self.last_interval, + 'limit': 5000 + } + with pytest.raises(AblyException, match="The maximum allowed limit is 1000"): + self.ably.stats(**params) + + def test_no_arguments(self): + params = { + 'end': self.last_interval, + } + stats_pages = self.ably.stats(**params) + self.stat = stats_pages.items[0] + assert self.stat.unit == 'minute' + + def test_got_1_record(self): + stats_pages = self.ably.stats(**self.get_params()) + assert 1 == len(stats_pages.items), "Expected 1 record" + + def test_return_aggregated_message_data(self): + # returns aggregated message data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.all.messages.count"] == 70 + 40 + assert stat.entries["messages.all.messages.data"] == 7000 + 4000 + + def test_inbound_realtime_all_data(self): + # returns inbound realtime all data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.inbound.realtime.all.count"] == 70 + assert stat.entries["messages.inbound.realtime.all.data"] == 7000 + + def test_inboud_realtime_message_data(self): + # returns inbound realtime message data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.inbound.realtime.messages.count"] == 70 + assert stat.entries["messages.inbound.realtime.messages.data"] == 7000 + + def test_outbound_realtime_all_data(self): + # returns outboud realtime all data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.outbound.realtime.all.count"] == 40 + assert stat.entries["messages.outbound.realtime.all.data"] == 4000 + + def test_persisted_data(self): + # returns persisted presence all data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["messages.persisted.all.count"] == 20 + assert stat.entries["messages.persisted.all.data"] == 2000 + + def test_connections_data(self): + # returns connections all data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["connections.all.peak"] == 20 + assert stat.entries["connections.all.opened"] == 10 + + def test_channels_all_data(self): + # returns channels all data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["channels.peak"] == 50 + assert stat.entries["channels.opened"] == 30 + + def test_api_requests_data(self): + # returns api_requests data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["apiRequests.other.succeeded"] == 50 + assert stat.entries["apiRequests.other.failed"] == 10 + + def test_token_requests(self): + # returns token_requests data + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.entries["apiRequests.tokenRequests.succeeded"] == 60 + assert stat.entries["apiRequests.tokenRequests.failed"] == 20 + + def test_interval(self): + # interval + stats_pages = self.ably.stats(**self.get_params()) + stats = stats_pages.items + stat = stats[0] + assert stat.unit == 'minute' + assert stat.interval_id == self.last_interval.strftime('%Y-%m-%d:%H:%M') + assert stat.interval_time == self.last_interval diff --git a/test/ably/sync/rest/resttime_test.py b/test/ably/sync/rest/resttime_test.py new file mode 100644 index 00000000..70116864 --- /dev/null +++ b/test/ably/sync/rest/resttime_test.py @@ -0,0 +1,43 @@ +import time + +import pytest + +from ably.sync import AblyException + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + + +class TestRestTime(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + def setUp(self): + self.ably = TestApp.get_ably_rest() + + def tearDown(self): + self.ably.close() + + def test_time_accuracy(self): + reported_time = self.ably.time() + actual_time = time.time() * 1000.0 + + seconds = 10 + assert abs(actual_time - reported_time) < seconds * 1000, "Time is not within %s seconds" % seconds + + def test_time_without_key_or_token(self): + reported_time = self.ably.time() + actual_time = time.time() * 1000.0 + + seconds = 10 + assert abs(actual_time - reported_time) < seconds * 1000, "Time is not within %s seconds" % seconds + + @dont_vary_protocol + def test_time_fails_without_valid_host(self): + ably = TestApp.get_ably_rest(key=None, token='foo', rest_host="this.host.does.not.exist") + with pytest.raises(AblyException): + ably.time() + + ably.close() diff --git a/test/ably/sync/rest/resttoken_test.py b/test/ably/sync/rest/resttoken_test.py new file mode 100644 index 00000000..f43bcbd8 --- /dev/null +++ b/test/ably/sync/rest/resttoken_test.py @@ -0,0 +1,342 @@ +import datetime +import json +import logging + +from mock import patch +import pytest + +from ably.sync import AblyException +from ably.sync import AblyRest +from ably.sync import Capability +from ably.sync.types.tokendetails import TokenDetails +from ably.sync.types.tokenrequest import TokenRequest + +from test.ably.sync.testapp import TestApp +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase + +log = logging.getLogger(__name__) + + +class TestRestToken(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def server_time(self): + return self.ably.time() + + def setUp(self): + capability = {"*": ["*"]} + self.permit_all = str(Capability(capability)) + self.ably = TestApp.get_ably_rest() + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + def test_request_token_null_params(self): + pre_time = self.server_time() + token_details = self.ably.auth.request_token() + post_time = self.server_time() + assert token_details.token is not None, "Expected token" + assert token_details.issued + 300 >= pre_time, "Unexpected issued time" + assert token_details.issued <= post_time, "Unexpected issued time" + assert self.permit_all == str(token_details.capability), "Unexpected capability" + + def test_request_token_explicit_timestamp(self): + pre_time = self.server_time() + token_details = self.ably.auth.request_token(token_params={'timestamp': pre_time}) + post_time = self.server_time() + assert token_details.token is not None, "Expected token" + assert token_details.issued + 300 >= pre_time, "Unexpected issued time" + assert token_details.issued <= post_time, "Unexpected issued time" + assert self.permit_all == str(Capability(token_details.capability)), "Unexpected Capability" + + def test_request_token_explicit_invalid_timestamp(self): + request_time = self.server_time() + explicit_timestamp = request_time - 30 * 60 * 1000 + + with pytest.raises(AblyException): + self.ably.auth.request_token(token_params={'timestamp': explicit_timestamp}) + + def test_request_token_with_system_timestamp(self): + pre_time = self.server_time() + token_details = self.ably.auth.request_token(query_time=True) + post_time = self.server_time() + assert token_details.token is not None, "Expected token" + assert token_details.issued >= pre_time, "Unexpected issued time" + assert token_details.issued <= post_time, "Unexpected issued time" + assert self.permit_all == str(Capability(token_details.capability)), "Unexpected Capability" + + def test_request_token_with_duplicate_nonce(self): + request_time = self.server_time() + token_params = { + 'timestamp': request_time, + 'nonce': '1234567890123456' + } + token_details = self.ably.auth.request_token(token_params) + assert token_details.token is not None, "Expected token" + + with pytest.raises(AblyException): + self.ably.auth.request_token(token_params) + + def test_request_token_with_capability_that_subsets_key_capability(self): + capability = Capability({ + "onlythischannel": ["subscribe"] + }) + + token_details = self.ably.auth.request_token( + token_params={'capability': capability}) + + assert token_details is not None + assert token_details.token is not None + assert capability == token_details.capability, "Unexpected capability" + + def test_request_token_with_specified_key(self): + test_vars = TestApp.get_test_vars() + key = test_vars["keys"][1] + token_details = self.ably.auth.request_token( + key_name=key["key_name"], key_secret=key["key_secret"]) + assert token_details.token is not None, "Expected token" + assert key.get("capability") == token_details.capability, "Unexpected capability" + + @dont_vary_protocol + def test_request_token_with_invalid_mac(self): + with pytest.raises(AblyException): + self.ably.auth.request_token(token_params={'mac': "thisisnotavalidmac"}) + + def test_request_token_with_specified_ttl(self): + token_details = self.ably.auth.request_token(token_params={'ttl': 100}) + assert token_details.token is not None, "Expected token" + assert token_details.issued + 100 == token_details.expires, "Unexpected expires" + + @dont_vary_protocol + def test_token_with_excessive_ttl(self): + excessive_ttl = 365 * 24 * 60 * 60 * 1000 + with pytest.raises(AblyException): + self.ably.auth.request_token(token_params={'ttl': excessive_ttl}) + + @dont_vary_protocol + def test_token_generation_with_invalid_ttl(self): + with pytest.raises(AblyException): + self.ably.auth.request_token(token_params={'ttl': -1}) + + def test_token_generation_with_local_time(self): + timestamp = self.ably.auth._timestamp + with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + self.ably.auth.request_token() + assert local_time.called + assert not server_time.called + + # RSA10k + def test_token_generation_with_server_time(self): + timestamp = self.ably.auth._timestamp + with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + self.ably.auth.request_token(query_time=True) + assert local_time.call_count == 1 + assert server_time.call_count == 1 + self.ably.auth.request_token(query_time=True) + assert local_time.call_count == 2 + assert server_time.call_count == 1 + + # TD7 + def test_toke_details_from_json(self): + token_details = self.ably.auth.request_token() + token_details_dict = token_details.to_dict() + token_details_str = json.dumps(token_details_dict) + + assert token_details == TokenDetails.from_json(token_details_dict) + assert token_details == TokenDetails.from_json(token_details_str) + + # Issue #71 + @dont_vary_protocol + def test_request_token_float_and_timedelta(self): + lifetime = datetime.timedelta(hours=4) + self.ably.auth.request_token({'ttl': lifetime.total_seconds() * 1000}) + self.ably.auth.request_token({'ttl': lifetime}) + + +class TestCreateTokenRequest(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + def setUp(self): + self.ably = TestApp.get_ably_rest() + self.key_name = self.ably.options.key_name + self.key_secret = self.ably.options.key_secret + + def tearDown(self): + self.ably.close() + + def per_protocol_setup(self, use_binary_protocol): + self.ably.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + @dont_vary_protocol + def test_key_name_and_secret_are_required(self): + ably = TestApp.get_ably_rest(key=None, token='not a real token') + with pytest.raises(AblyException, match="40101 401 No key specified"): + ably.auth.create_token_request() + with pytest.raises(AblyException, match="40101 401 No key specified"): + ably.auth.create_token_request(key_name=self.key_name) + with pytest.raises(AblyException, match="40101 401 No key specified"): + ably.auth.create_token_request(key_secret=self.key_secret) + + @dont_vary_protocol + def test_with_local_time(self): + timestamp = self.ably.auth._timestamp + with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret, query_time=False) + assert local_time.called + assert not server_time.called + + # RSA10k + @dont_vary_protocol + def test_with_server_time(self): + timestamp = self.ably.auth._timestamp + with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret, query_time=True) + assert local_time.call_count == 1 + assert server_time.call_count == 1 + self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret, query_time=True) + assert local_time.call_count == 2 + assert server_time.call_count == 1 + + def test_token_request_can_be_used_to_get_a_token(self): + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert isinstance(token_request, TokenRequest) + + def auth_callback(token_params): + return token_request + + ably = TestApp.get_ably_rest(key=None, + auth_callback=auth_callback, + use_binary_protocol=self.use_binary_protocol) + + token = ably.auth.authorize() + assert isinstance(token, TokenDetails) + ably.close() + + def test_token_request_dict_can_be_used_to_get_a_token(self): + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert isinstance(token_request, TokenRequest) + + def auth_callback(token_params): + return token_request.to_dict() + + ably = TestApp.get_ably_rest(key=None, + auth_callback=auth_callback, + use_binary_protocol=self.use_binary_protocol) + + token = ably.auth.authorize() + assert isinstance(token, TokenDetails) + ably.close() + + # TE6 + @dont_vary_protocol + def test_token_request_from_json(self): + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert isinstance(token_request, TokenRequest) + + token_request_dict = token_request.to_dict() + assert token_request == TokenRequest.from_json(token_request_dict) + + token_request_str = json.dumps(token_request_dict) + assert token_request == TokenRequest.from_json(token_request_str) + + @dont_vary_protocol + def test_nonce_is_random_and_longer_than_15_characters(self): + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert len(token_request.nonce) > 15 + + another_token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert len(another_token_request.nonce) > 15 + + assert token_request.nonce != another_token_request.nonce + + # RSA5 + @dont_vary_protocol + def test_ttl_is_optional_and_specified_in_ms(self): + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert token_request.ttl is None + + # RSA6 + @dont_vary_protocol + def test_capability_is_optional(self): + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret) + assert token_request.capability is None + + @dont_vary_protocol + def test_accept_all_token_params(self): + token_params = { + 'ttl': 1000, + 'capability': Capability({'channel': ['publish']}), + 'client_id': 'a_id', + 'timestamp': 1000, + 'nonce': 'a_nonce', + } + token_request = self.ably.auth.create_token_request( + token_params, + key_name=self.key_name, key_secret=self.key_secret, + ) + assert token_request.ttl == token_params['ttl'] + assert token_request.capability == str(token_params['capability']) + assert token_request.client_id == token_params['client_id'] + assert token_request.timestamp == token_params['timestamp'] + assert token_request.nonce == token_params['nonce'] + + def test_capability(self): + capability = Capability({'channel': ['publish']}) + token_request = self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret, + token_params={'capability': capability}) + assert token_request.capability == str(capability) + + def auth_callback(token_params): + return token_request + + ably = TestApp.get_ably_rest(key=None, auth_callback=auth_callback, + use_binary_protocol=self.use_binary_protocol) + + token = ably.auth.authorize() + + assert str(token.capability) == str(capability) + ably.close() + + @dont_vary_protocol + def test_hmac(self): + ably = AblyRest(key_name='a_key_name', key_secret='a_secret') + token_params = { + 'ttl': 1000, + 'nonce': 'abcde100', + 'client_id': 'a_id', + 'timestamp': 1000, + } + token_request = ably.auth.create_token_request( + token_params, key_secret='a_secret', key_name='a_key_name') + assert token_request.mac == 'sYkCH0Un+WgzI7/Nhy0BoQIKq9HmjKynCRs4E3qAbGQ=' + ably.close() + + # AO2g + @dont_vary_protocol + def test_query_server_time(self): + with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time: + self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret, query_time=True) + assert server_time.call_count == 1 + + self.ably.auth.create_token_request( + key_name=self.key_name, key_secret=self.key_secret, query_time=False) + assert server_time.call_count == 1 diff --git a/test/ably/sync/testapp.py b/test/ably/sync/testapp.py new file mode 100644 index 00000000..54c0af02 --- /dev/null +++ b/test/ably/sync/testapp.py @@ -0,0 +1,115 @@ +import json +import os +import logging + +from ably.sync.rest.rest import AblyRest +from ably.sync.types.capability import Capability +from ably.sync.types.options import Options +from ably.sync.util.exceptions import AblyException +from ably.sync.realtime.realtime import AblyRealtime + +log = logging.getLogger(__name__) + +with open(os.path.dirname(__file__) + '/../../assets/testAppSpec.json', 'r') as f: + app_spec_local = json.loads(f.read()) + +tls = (os.environ.get('ABLY_TLS') or "true").lower() == "true" +rest_host = os.environ.get('ABLY_REST_HOST', 'sandbox-rest.ably.io') +realtime_host = os.environ.get('ABLY_REALTIME_HOST', 'sandbox-realtime.ably.io') + +environment = os.environ.get('ABLY_ENV', 'sandbox') + +port = 80 +tls_port = 443 + +if rest_host and not rest_host.endswith("rest.ably.io"): + tls = tls and rest_host != "localhost" + port = 8080 + tls_port = 8081 + + +ably = AblyRest(token='not_a_real_token', + port=port, tls_port=tls_port, tls=tls, + environment=environment, + use_binary_protocol=False) + + +class TestApp: + __test_vars = None + + @staticmethod + def get_test_vars(): + if not TestApp.__test_vars: + r = ably.http.post("/apps", body=app_spec_local, skip_auth=True) + AblyException.raise_for_response(r) + + app_spec = r.json() + + app_id = app_spec.get("appId", "") + + test_vars = { + "app_id": app_id, + "host": rest_host, + "port": port, + "tls_port": tls_port, + "tls": tls, + "environment": environment, + "realtime_host": realtime_host, + "keys": [{ + "key_name": "%s.%s" % (app_id, k.get("id", "")), + "key_secret": k.get("value", ""), + "key_str": "%s.%s:%s" % (app_id, k.get("id", ""), k.get("value", "")), + "capability": Capability(json.loads(k.get("capability", "{}"))), + } for k in app_spec.get("keys", [])] + } + + TestApp.__test_vars = test_vars + log.debug([(app_id, k.get("id", ""), k.get("value", "")) + for k in app_spec.get("keys", [])]) + + return TestApp.__test_vars + + @staticmethod + def get_ably_rest(**kw): + test_vars = TestApp.get_test_vars() + options = TestApp.get_options(test_vars, **kw) + options.update(kw) + return AblyRest(**options) + + @staticmethod + def get_ably_realtime(**kw): + test_vars = TestApp.get_test_vars() + options = TestApp.get_options(test_vars, **kw) + return AblyRealtime(**options) + + @staticmethod + def get_options(test_vars, **kwargs): + options = { + 'port': test_vars["port"], + 'tls_port': test_vars["tls_port"], + 'tls': test_vars["tls"], + 'environment': test_vars["environment"], + } + auth_methods = ["auth_url", "auth_callback", "token", "token_details", "key"] + if not any(x in kwargs for x in auth_methods): + options["key"] = test_vars["keys"][0]["key_str"] + + if any(x in kwargs for x in ["rest_host", "realtime_host"]): + options["environment"] = None + + options.update(kwargs) + + return options + + @staticmethod + def clear_test_vars(): + test_vars = TestApp.__test_vars + options = Options(key=test_vars["keys"][0]["key_str"]) + options.rest_host = test_vars["host"] + options.port = test_vars["port"] + options.tls_port = test_vars["tls_port"] + options.tls = test_vars["tls"] + ably = TestApp.get_ably_rest() + ably.http.delete('/apps/' + test_vars['app_id']) + TestApp.__test_vars = None + ably.close() diff --git a/test/ably/sync/utils.py b/test/ably/sync/utils.py new file mode 100644 index 00000000..c3d68f79 --- /dev/null +++ b/test/ably/sync/utils.py @@ -0,0 +1,168 @@ +import functools +import random +import string +import unittest +import sys +if sys.version_info >= (3, 8): + from unittest import IsolatedAsyncioTestCase +else: + from async_case import IsolatedAsyncioTestCase + +import msgpack +import mock +import respx +from httpx import Response + +from ably.sync.http.http import Http + + +class BaseTestCase(unittest.TestCase): + + def respx_add_empty_msg_pack(self, url, method='GET'): + respx.route(method=method, url=url).return_value = Response( + status_code=200, + headers={'content-type': 'application/x-msgpack'}, + content=msgpack.packb({}) + ) + + @classmethod + def get_channel_name(cls, prefix=''): + return prefix + random_string(10) + + @classmethod + def get_channel(cls, prefix=''): + name = cls.get_channel_name(prefix) + return cls.ably.channels.get(name) + + +class BaseAsyncTestCase(IsolatedAsyncioTestCase): + + def respx_add_empty_msg_pack(self, url, method='GET'): + respx.route(method=method, url=url).return_value = Response( + status_code=200, + headers={'content-type': 'application/x-msgpack'}, + content=msgpack.packb({}) + ) + + @classmethod + def get_channel_name(cls, prefix=''): + return prefix + random_string(10) + + def get_channel(self, prefix=''): + name = self.get_channel_name(prefix) + return self.ably.channels.get(name) + + +def assert_responses_type(protocol): + """ + This is a decorator to check if we retrieved responses with the correct protocol. + usage: + + @assert_responses_type('json') + def test_something(self): + ... + + this will check if all responses received during the test will be in the format + json. + supports json and msgpack + """ + responses = [] + + def patch(): + original = Http.make_request + + def fake_make_request(self, *args, **kwargs): + response = original(self, *args, **kwargs) + responses.append(response) + return response + + patcher = mock.patch.object(Http, 'make_request', fake_make_request) + patcher.start() + return patcher + + def unpatch(patcher): + patcher.stop() + + def test_decorator(fn): + @functools.wraps(fn) + def test_decorated(self, *args, **kwargs): + patcher = patch() + fn(self, *args, **kwargs) + unpatch(patcher) + + assert len(responses) >= 1,\ + "If your test doesn't make any requests, use the @dont_vary_protocol decorator" + + for response in responses: + # In HTTP/2 some header fields are optional in case of 204 status code + if protocol == 'json': + if response.status_code != 204: + assert response.headers['content-type'] == 'application/json' + if response.content: + response.json() + else: + if response.status_code != 204: + assert response.headers['content-type'] == 'application/x-msgpack' + if response.content: + msgpack.unpackb(response.content) + + return test_decorated + return test_decorator + + +class VaryByProtocolTestsMetaclass(type): + """ + Metaclass to run tests in more than one protocol. + Usage: + * set this as metaclass of the TestCase class + * create the following method: + def per_protocol_setup(self, use_binary_protocol): + # do something here that will run before each test. + * now every test will run twice and before test is run per_protocol_setup + is called + * exclude tests with the @dont_vary_protocol decorator + """ + def __new__(cls, clsname, bases, dct): + for key, value in tuple(dct.items()): + if key.startswith('test') and not getattr(value, 'dont_vary_protocol', + False): + + wrapper_bin = cls.wrap_as('bin', key, value) + wrapper_text = cls.wrap_as('text', key, value) + + dct[key + '_bin'] = wrapper_bin + dct[key + '_text'] = wrapper_text + del dct[key] + + return super().__new__(cls, clsname, bases, dct) + + @staticmethod + def wrap_as(ttype, old_name, old_func): + expected_content = {'bin': 'msgpack', 'text': 'json'} + + @assert_responses_type(expected_content[ttype]) + def wrapper(self): + if hasattr(self, 'per_protocol_setup'): + self.per_protocol_setup(ttype == 'bin') + old_func(self) + wrapper.__name__ = old_name + '_' + ttype + return wrapper + + +def dont_vary_protocol(func): + func.dont_vary_protocol = True + return func + + +def random_string(length, alphabet=string.ascii_letters): + return ''.join([random.choice(alphabet) for x in range(length)]) + + +def new_dict(src, **kw): + new = src.copy() + new.update(kw) + return new + + +def get_random_key(d): + return random.choice(list(d)) From b4fa9f438a424f18b0d1c3a7be0d2a518eb093c4 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:10:37 +0530 Subject: [PATCH 15/52] Added auto indentation code to unasync file --- unasync.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/unasync.py b/unasync.py index 73a70651..d18cb0a0 100644 --- a/unasync.py +++ b/unasync.py @@ -74,12 +74,37 @@ def _unasync_file(self, filepath): def _unasync_tokens(self, tokens: list): new_tokens = [] token_counter = 0 + async_await_block_started = False + async_await_offset = 0 while token_counter < len(tokens): token = tokens[token_counter] + if async_await_block_started: + if token.src == '\n': + new_tokens.append(token) + token_counter = token_counter + 1 + next_newline_token = tokens[token_counter] + if len(next_newline_token.src) >= 6 and tokens[token_counter+1].utf8_byte_offset > async_await_offset: + new_tab_indentation = next_newline_token.src[:-6] # remove last 6 white spaces + next_newline_token = next_newline_token._replace(src=new_tab_indentation) + new_tokens.append(next_newline_token) + else: + new_tokens.append(next_newline_token) + token_counter = token_counter + 1 + continue + + if token.src == ')': + async_await_block_started = False + async_await_offset = 0 + if token.src in ["async", "await"]: # When removing async or await, we want to skip the following whitespace token_counter = token_counter + 2 + if (tokens[token_counter].src == 'def' or tokens[token_counter + 1].src == '(' or + tokens[token_counter + 2].src == '(' or tokens[token_counter + 3].src == "("): + # Fix indentation issues for async/await fn definition/call + async_await_offset = token.utf8_byte_offset + async_await_block_started = True continue elif token.name == "NAME": if token.src == "from": From 3a51a5370b05c07cc8c46d6a98af6ad4ea4416d1 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:11:40 +0530 Subject: [PATCH 16/52] Fixed indentation based on new formula --- ably/sync/http/http.py | 12 ++++++------ ably/sync/http/paginatedresult.py | 6 +++--- ably/sync/realtime/connectionmanager.py | 2 +- ably/sync/rest/auth.py | 14 +++++++------- ably/sync/rest/push.py | 4 ++-- ably/sync/rest/rest.py | 4 ++-- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/ably/sync/http/http.py b/ably/sync/http/http.py index 8e52da55..3fcba89b 100644 --- a/ably/sync/http/http.py +++ b/ably/sync/http/http.py @@ -158,7 +158,7 @@ def get_rest_hosts(self): @reauth_if_expired def make_request(self, method, path, version=None, headers=None, body=None, - skip_auth=False, timeout=None, raise_on_error=True): + skip_auth=False, timeout=None, raise_on_error=True): if body is not None and type(body) not in (bytes, str): body = self.dump_body(body) @@ -229,27 +229,27 @@ def make_request(self, method, path, version=None, headers=None, body=None, def delete(self, url, headers=None, skip_auth=False, timeout=None): result = self.make_request('DELETE', url, headers=headers, - skip_auth=skip_auth, timeout=timeout) + skip_auth=skip_auth, timeout=timeout) return result def get(self, url, headers=None, skip_auth=False, timeout=None): result = self.make_request('GET', url, headers=headers, - skip_auth=skip_auth, timeout=timeout) + skip_auth=skip_auth, timeout=timeout) return result def patch(self, url, headers=None, body=None, skip_auth=False, timeout=None): result = self.make_request('PATCH', url, headers=headers, body=body, - skip_auth=skip_auth, timeout=timeout) + skip_auth=skip_auth, timeout=timeout) return result def post(self, url, headers=None, body=None, skip_auth=False, timeout=None): result = self.make_request('POST', url, headers=headers, body=body, - skip_auth=skip_auth, timeout=timeout) + skip_auth=skip_auth, timeout=timeout) return result def put(self, url, headers=None, body=None, skip_auth=False, timeout=None): result = self.make_request('PUT', url, headers=headers, body=body, - skip_auth=skip_auth, timeout=timeout) + skip_auth=skip_auth, timeout=timeout) return result @property diff --git a/ably/sync/http/paginatedresult.py b/ably/sync/http/paginatedresult.py index 8dbc78ec..4f47075a 100644 --- a/ably/sync/http/paginatedresult.py +++ b/ably/sync/http/paginatedresult.py @@ -78,8 +78,8 @@ def __get_rel(self, rel_req): @classmethod def paginated_query(cls, http, method='GET', url='/', version=None, body=None, - headers=None, response_processor=None, - raise_on_error=True): + headers=None, response_processor=None, + raise_on_error=True): headers = headers or {} req = Request(method, url, version=version, body=body, headers=headers, skip_auth=False, raise_on_error=raise_on_error) @@ -87,7 +87,7 @@ def paginated_query(cls, http, method='GET', url='/', version=None, body=None, @classmethod def paginated_query_with_request(cls, http, request, response_processor, - raise_on_error=True): + raise_on_error=True): response = http.make_request( request.method, request.url, version=request.version, headers=request.headers, body=request.body, diff --git a/ably/sync/realtime/connectionmanager.py b/ably/sync/realtime/connectionmanager.py index 0be5a427..7e5fd820 100644 --- a/ably/sync/realtime/connectionmanager.py +++ b/ably/sync/realtime/connectionmanager.py @@ -130,7 +130,7 @@ def ping(self) -> float: self.__ping_id = get_random_id() ping_start_time = datetime.now().timestamp() self.send_protocol_message({"action": ProtocolMessageAction.HEARTBEAT, - "id": self.__ping_id}) + "id": self.__ping_id}) else: raise AblyException("Cannot send ping request. Calling ping in invalid state", 40000, 400) try: diff --git a/ably/sync/rest/auth.py b/ably/sync/rest/auth.py index a35e1fc2..e310b550 100644 --- a/ably/sync/rest/auth.py +++ b/ably/sync/rest/auth.py @@ -152,11 +152,11 @@ def authorize(self, token_params: Optional[dict] = None, auth_options=None): return self.__authorize_when_necessary(token_params, auth_options, force=True) def request_token(self, token_params: Optional[dict] = None, - # auth_options - key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None, - auth_url: Optional[str] = None, auth_method: Optional[str] = None, - auth_headers: Optional[dict] = None, auth_params: Optional[dict] = None, - query_time=None): + # auth_options + key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None, + auth_url: Optional[str] = None, auth_method: Optional[str] = None, + auth_headers: Optional[dict] = None, auth_params: Optional[dict] = None, + query_time=None): token_params = token_params or {} token_params = dict(self.auth_options.default_token_params, **token_params) @@ -230,7 +230,7 @@ def request_token(self, token_params: Optional[dict] = None, return TokenDetails.from_dict(response_dict) def create_token_request(self, token_params: Optional[dict] = None, key_name: Optional[str] = None, - key_secret: Optional[str] = None, query_time=None): + key_secret: Optional[str] = None, query_time=None): token_params = token_params or {} token_request = {} @@ -387,7 +387,7 @@ def _random_nonce(self): return uuid.uuid4().hex[:16] def token_request_from_auth_url(self, method: str, url: str, token_params, - headers, auth_params): + headers, auth_params): body = None params = None if method == 'GET': diff --git a/ably/sync/rest/push.py b/ably/sync/rest/push.py index fabb2c1a..6133f85f 100644 --- a/ably/sync/rest/push.py +++ b/ably/sync/rest/push.py @@ -142,7 +142,7 @@ def list(self, **params): """ path = '/push/channelSubscriptions' + format_params(params) return PaginatedResult.paginated_query(self.ably.http, url=path, - response_processor=channel_subscriptions_response_processor) + response_processor=channel_subscriptions_response_processor) def list_channels(self, **params): """Returns a PaginatedResult object with the list of @@ -153,7 +153,7 @@ def list_channels(self, **params): """ path = '/push/channels' + format_params(params) return PaginatedResult.paginated_query(self.ably.http, url=path, - response_processor=channels_response_processor) + response_processor=channels_response_processor) def save(self, subscription: dict): """Creates or updates the subscription. Returns a diff --git a/ably/sync/rest/rest.py b/ably/sync/rest/rest.py index ff163967..56cc3723 100644 --- a/ably/sync/rest/rest.py +++ b/ably/sync/rest/rest.py @@ -80,7 +80,7 @@ def __enter__(self): @catch_all def stats(self, direction: Optional[str] = None, start=None, end=None, params: Optional[dict] = None, - limit: Optional[int] = None, paginated=None, unit=None, timeout=None): + limit: Optional[int] = None, paginated=None, unit=None, timeout=None): """Returns the stats for this application""" formatted_params = format_params(params, direction=direction, start=start, end=end, limit=limit, unit=unit) url = '/stats' + formatted_params @@ -120,7 +120,7 @@ def push(self): return self.__push def request(self, method: str, path: str, version: str, params: - Optional[dict] = None, body=None, headers=None): + Optional[dict] = None, body=None, headers=None): if version is None: raise AblyException("No version parameter", 400, 40000) From b89af9518244252163a81cea2d0698827d732f9f Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:24:44 +0530 Subject: [PATCH 17/52] Merged unasync_test into unasync, refactored code --- unasync.py | 89 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 35 deletions(-) diff --git a/unasync.py b/unasync.py index d18cb0a0..9d6d2d76 100644 --- a/unasync.py +++ b/unasync.py @@ -29,6 +29,10 @@ } +_STRING_REPLACE = { +} + + class Rule: """A single set of rules for 'unasync'ing file(s)""" @@ -80,6 +84,7 @@ def _unasync_tokens(self, tokens: list): token = tokens[token_counter] if async_await_block_started: + # Fix indentation issues for async/await fn definition/call if token.src == '\n': new_tokens.append(token) token_counter = token_counter + 1 @@ -106,6 +111,7 @@ def _unasync_tokens(self, tokens: list): async_await_offset = token.utf8_byte_offset async_await_block_started = True continue + elif token.name == "NAME": if token.src == "from": if tokens[token_counter + 1].src == " ": @@ -114,44 +120,16 @@ def _unasync_tokens(self, tokens: list): else: token = token._replace(src=self._unasync_name(token.src)) elif token.name == "STRING": - left_quote, name, right_quote = ( - token.src[0], - token.src[1:-1], - token.src[-1], - ) - token = token._replace( - src=left_quote + self._unasync_name(name) + right_quote - ) + src_token = token.src.replace("'", "") + if _STRING_REPLACE.get(src_token) is not None: + new_token = f"'{_STRING_REPLACE[src_token]}'" + token = token._replace(src=new_token) new_tokens.append(token) token_counter = token_counter + 1 return new_tokens - # for i, token in enumerate(tokens): - # if skip_next: - # skip_next = False - # continue - # - # if token.src in ["async", "await"]: - # # When removing async or await, we want to skip the following whitespace - # # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` - # skip_next = True - # else: - # if token.name == "NAME": - # token = token._replace(src=self._unasync_name(token.src)) - # elif token.name == "STRING": - # left_quote, name, right_quote = ( - # token.src[0], - # token.src[1:-1], - # token.src[-1], - # ) - # token = token._replace( - # src=left_quote + self._unasync_name(name) + right_quote - # ) - # - # yield token - def _replace_import(self, tokens, token_counter, new_tokens: list): new_tokens.append(tokens[token_counter]) new_tokens.append(tokens[token_counter + 1]) @@ -182,9 +160,6 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): def _unasync_name(self, name): if name in self.token_replacements: return self.token_replacements[name] - # Convert classes prefixed with 'Async' into 'Sync' - # elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): - # return "Sync" + name[5:] return name @@ -207,6 +182,8 @@ def unasync_files(fpath_list, rules): Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) +# Source files ========================================== + src_dir_path = os.path.join(os.getcwd(), "ably") dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") _DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) @@ -222,3 +199,45 @@ def find_files(dir_path, file_name_regex) -> list[str]: set(find_files(dest_dir_path, "*.py"))) unasync_files(list(relevant_src_files), (_DEFAULT_RULE,)) + +# Test files ============================================== + + +_ASYNC_TO_SYNC["AsyncClient"] = "Client" +_ASYNC_TO_SYNC["aclose"] = "close" +_ASYNC_TO_SYNC["asyncSetUp"] = "setUp" +_ASYNC_TO_SYNC["asyncTearDown"] = "tearDown" +_ASYNC_TO_SYNC["AsyncMock"] = "Mock" + +_IMPORTS_REPLACE["ably"] = "ably.sync" +_IMPORTS_REPLACE["test.ably"] = "test.ably.sync" + +_STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' +_STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.Auth.request_token' +_STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' + +Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) + +src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") +dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + +os.makedirs(dest_dir_path, exist_ok=True) + + +def find_files(dir_path, file_name_regex) -> list[str]: + return glob.glob(os.path.join(dir_path, file_name_regex), recursive=True) + + +src_files = find_files(src_dir_path, "*.py") +unasync_files(src_files, (_DEFAULT_RULE,)) + +# round 2 +src_dir_path = os.path.join(os.getcwd(), "test", "ably") +dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) + +src_files = [os.path.join(os.getcwd(), "test", "ably", "testapp.py"), + os.path.join(os.getcwd(), "test", "ably", "utils.py")] + +unasync_files(src_files, (_DEFAULT_RULE,)) From 7d24da3a777471ac690ae205d13502926a0c6af7 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:25:36 +0530 Subject: [PATCH 18/52] Executed updated unasync test for indentation --- test/ably/sync/rest/restauth_test.py | 8 ++++---- test/ably/sync/rest/restchannelhistory_test.py | 4 ++-- test/ably/sync/rest/restchannelpublish_test.py | 4 ++-- test/ably/sync/rest/restinit_test.py | 2 +- test/ably/sync/rest/resttoken_test.py | 10 +++++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/test/ably/sync/rest/restauth_test.py b/test/ably/sync/rest/restauth_test.py index 4ca85f45..7f601156 100644 --- a/test/ably/sync/rest/restauth_test.py +++ b/test/ably/sync/rest/restauth_test.py @@ -224,7 +224,7 @@ def test_with_token_str_https(self): token = self.ably.auth.authorize() token = token.token ably = TestApp.get_ably_rest(key=None, token=token, tls=True, - use_binary_protocol=self.use_binary_protocol) + use_binary_protocol=self.use_binary_protocol) ably.channels.test_auth_with_token_str.publish('event', 'foo_bar') ably.close() @@ -232,13 +232,13 @@ def test_with_token_str_http(self): token = self.ably.auth.authorize() token = token.token ably = TestApp.get_ably_rest(key=None, token=token, tls=False, - use_binary_protocol=self.use_binary_protocol) + use_binary_protocol=self.use_binary_protocol) ably.channels.test_auth_with_token_str.publish('event', 'foo_bar') ably.close() def test_if_default_client_id_is_used(self): ably = TestApp.get_ably_rest(client_id='my_client_id', - use_binary_protocol=self.use_binary_protocol) + use_binary_protocol=self.use_binary_protocol) token = ably.auth.authorize() assert token.client_id == 'my_client_id' ably.close() @@ -335,7 +335,7 @@ def test_with_key(self): ably.close() ably = TestApp.get_ably_rest(key=None, token_details=token_details, - use_binary_protocol=self.use_binary_protocol) + use_binary_protocol=self.use_binary_protocol) channel = self.get_channel_name('test_request_token_with_key') ably.channels[channel].publish('event', 'foo') diff --git a/test/ably/sync/rest/restchannelhistory_test.py b/test/ably/sync/rest/restchannelhistory_test.py index 3c82fcc8..14b86ac5 100644 --- a/test/ably/sync/rest/restchannelhistory_test.py +++ b/test/ably/sync/rest/restchannelhistory_test.py @@ -176,7 +176,7 @@ def test_channel_history_time_forwards(self): history0.publish('history%d' % i, str(i)) history = history0.history(direction='forwards', start=interval_start, - end=interval_end) + end=interval_end) messages = history.items assert 20 == len(messages) @@ -202,7 +202,7 @@ def test_channel_history_time_backwards(self): history0.publish('history%d' % i, str(i)) history = history0.history(direction='backwards', start=interval_start, - end=interval_end) + end=interval_end) messages = history.items assert 20 == len(messages) diff --git a/test/ably/sync/rest/restchannelpublish_test.py b/test/ably/sync/rest/restchannelpublish_test.py index a3c1ebcb..38bfb1b9 100644 --- a/test/ably/sync/rest/restchannelpublish_test.py +++ b/test/ably/sync/rest/restchannelpublish_test.py @@ -298,8 +298,8 @@ def test_publish_message_with_client_id_on_identified_client(self): def test_publish_message_with_wrong_client_id_on_implicit_identified_client(self): new_token = self.ably.auth.authorize(token_params={'client_id': uuid.uuid4().hex}) new_ably = TestApp.get_ably_rest(key=None, - token=new_token.token, - use_binary_protocol=self.use_binary_protocol) + token=new_token.token, + use_binary_protocol=self.use_binary_protocol) channel = new_ably.channels[ self.get_channel_name('persisted:wrong_client_id_implicit_client')] diff --git a/test/ably/sync/rest/restinit_test.py b/test/ably/sync/rest/restinit_test.py index 84743360..8a6864ad 100644 --- a/test/ably/sync/rest/restinit_test.py +++ b/test/ably/sync/rest/restinit_test.py @@ -165,7 +165,7 @@ def test_with_no_auth_params(self): # RSA10k def test_query_time_param(self): ably = TestApp.get_ably_rest(query_time=True, - use_binary_protocol=self.use_binary_protocol) + use_binary_protocol=self.use_binary_protocol) timestamp = ably.auth._timestamp with patch('ably.rest.rest.AblyRest.time', wraps=ably.time) as server_time,\ diff --git a/test/ably/sync/rest/resttoken_test.py b/test/ably/sync/rest/resttoken_test.py index f43bcbd8..d31e9441 100644 --- a/test/ably/sync/rest/resttoken_test.py +++ b/test/ably/sync/rest/resttoken_test.py @@ -216,8 +216,8 @@ def auth_callback(token_params): return token_request ably = TestApp.get_ably_rest(key=None, - auth_callback=auth_callback, - use_binary_protocol=self.use_binary_protocol) + auth_callback=auth_callback, + use_binary_protocol=self.use_binary_protocol) token = ably.auth.authorize() assert isinstance(token, TokenDetails) @@ -232,8 +232,8 @@ def auth_callback(token_params): return token_request.to_dict() ably = TestApp.get_ably_rest(key=None, - auth_callback=auth_callback, - use_binary_protocol=self.use_binary_protocol) + auth_callback=auth_callback, + use_binary_protocol=self.use_binary_protocol) token = ably.auth.authorize() assert isinstance(token, TokenDetails) @@ -308,7 +308,7 @@ def auth_callback(token_params): return token_request ably = TestApp.get_ably_rest(key=None, auth_callback=auth_callback, - use_binary_protocol=self.use_binary_protocol) + use_binary_protocol=self.use_binary_protocol) token = ably.auth.authorize() From 65b7936cc5b709061a65b663b87b872e03d08233 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:35:33 +0530 Subject: [PATCH 19/52] Fixed indentation issues with generated test files --- test/ably/sync/rest/restauth_test.py | 4 ++-- unasync.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/test/ably/sync/rest/restauth_test.py b/test/ably/sync/rest/restauth_test.py index 7f601156..b2845390 100644 --- a/test/ably/sync/rest/restauth_test.py +++ b/test/ably/sync/rest/restauth_test.py @@ -442,8 +442,8 @@ def test_when_auth_url_has_query_string(self): auth_route = respx.get('http://www.example.com', params={'with': 'query', 'spam': 'eggs'}).mock( return_value=Response(status_code=200, content='token_string', headers={"Content-Type": "text/plain"})) ably.auth.request_token(auth_url=url, - auth_headers=headers, - auth_params={'spam': 'eggs'}) + auth_headers=headers, + auth_params={'spam': 'eggs'}) assert auth_route.called ably.close() diff --git a/unasync.py b/unasync.py index 9d6d2d76..aa82a7f0 100644 --- a/unasync.py +++ b/unasync.py @@ -89,7 +89,7 @@ def _unasync_tokens(self, tokens: list): new_tokens.append(token) token_counter = token_counter + 1 next_newline_token = tokens[token_counter] - if len(next_newline_token.src) >= 6 and tokens[token_counter+1].utf8_byte_offset > async_await_offset: + if len(next_newline_token.src) >= 6 and tokens[token_counter+1].utf8_byte_offset >= async_await_offset + 6: new_tab_indentation = next_newline_token.src[:-6] # remove last 6 white spaces next_newline_token = next_newline_token._replace(src=new_tab_indentation) new_tokens.append(next_newline_token) @@ -105,8 +105,13 @@ def _unasync_tokens(self, tokens: list): if token.src in ["async", "await"]: # When removing async or await, we want to skip the following whitespace token_counter = token_counter + 2 - if (tokens[token_counter].src == 'def' or tokens[token_counter + 1].src == '(' or - tokens[token_counter + 2].src == '(' or tokens[token_counter + 3].src == "("): + is_async_start = tokens[token_counter].src == 'def' + is_await_start = False + for i in range(token_counter, token_counter + 6): + if tokens[i].src == '(': + is_await_start = True + break + if is_async_start or is_await_start: # Fix indentation issues for async/await fn definition/call async_await_offset = token.utf8_byte_offset async_await_block_started = True From 22836bc4a3190fe8dec32bf47702c56a7cbc5b30 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:48:14 +0530 Subject: [PATCH 20/52] Fixed indentation issues for restcapability --- test/ably/rest/restcapability_test.py | 3 +-- test/ably/sync/rest/restcapability_test.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/ably/rest/restcapability_test.py b/test/ably/rest/restcapability_test.py index 0182dcb0..f7c761ab 100644 --- a/test/ably/rest/restcapability_test.py +++ b/test/ably/rest/restcapability_test.py @@ -21,8 +21,7 @@ def per_protocol_setup(self, use_binary_protocol): async def test_blanket_intersection_with_key(self): key = self.test_vars['keys'][1] - token_details = await self.ably.auth.request_token(key_name=key['key_name'], - key_secret=key['key_secret']) + token_details = await self.ably.auth.request_token(key_name=key['key_name'], key_secret=key['key_secret']) expected_capability = Capability(key["capability"]) assert token_details.token is not None, "Expected token" assert expected_capability == token_details.capability, "Unexpected capability." diff --git a/test/ably/sync/rest/restcapability_test.py b/test/ably/sync/rest/restcapability_test.py index 486f148c..224c5d66 100644 --- a/test/ably/sync/rest/restcapability_test.py +++ b/test/ably/sync/rest/restcapability_test.py @@ -21,8 +21,7 @@ def per_protocol_setup(self, use_binary_protocol): def test_blanket_intersection_with_key(self): key = self.test_vars['keys'][1] - token_details = self.ably.auth.request_token(key_name=key['key_name'], - key_secret=key['key_secret']) + token_details = self.ably.auth.request_token(key_name=key['key_name'], key_secret=key['key_secret']) expected_capability = Capability(key["capability"]) assert token_details.token is not None, "Expected token" assert expected_capability == token_details.capability, "Unexpected capability." From 68c30481e4776fcf93eb67ec3dada689bbd25ab0 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:51:40 +0530 Subject: [PATCH 21/52] Reformatted unasync file, removed unasync_test file --- unasync.py | 3 +- unasync_test.py | 213 ------------------------------------------------ 2 files changed, 2 insertions(+), 214 deletions(-) delete mode 100644 unasync_test.py diff --git a/unasync.py b/unasync.py index aa82a7f0..aa55a84b 100644 --- a/unasync.py +++ b/unasync.py @@ -89,7 +89,8 @@ def _unasync_tokens(self, tokens: list): new_tokens.append(token) token_counter = token_counter + 1 next_newline_token = tokens[token_counter] - if len(next_newline_token.src) >= 6 and tokens[token_counter+1].utf8_byte_offset >= async_await_offset + 6: + if (len(next_newline_token.src) >= 6 and + tokens[token_counter + 1].utf8_byte_offset >= async_await_offset + 6): new_tab_indentation = next_newline_token.src[:-6] # remove last 6 white spaces next_newline_token = next_newline_token._replace(src=new_tab_indentation) new_tokens.append(next_newline_token) diff --git a/unasync_test.py b/unasync_test.py deleted file mode 100644 index 692e86cb..00000000 --- a/unasync_test.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Top-level package for unasync.""" - -import collections -import glob -import os -import tokenize as std_tokenize - -import tokenize_rt - -_ASYNC_TO_SYNC = { - "__aenter__": "__enter__", - "__aexit__": "__exit__", - "__aiter__": "__iter__", - "__anext__": "__next__", - "asynccontextmanager": "contextmanager", - "AsyncIterable": "Iterable", - "AsyncIterator": "Iterator", - "AsyncGenerator": "Generator", - # TODO StopIteration is still accepted in Python 2, but the right change - # is 'raise StopAsyncIteration' -> 'return' since we want to use unasynced - # code in Python 3.7+ - "StopAsyncIteration": "StopIteration", - "AsyncClient": "Client", - "aclose": "close", - "asyncSetUp": "setUp", - "asyncTearDown": "tearDown", - "AsyncMock": "Mock" -} - -_IMPORTS_REPLACE = { - -} - -_STRING_REPLACE = { -} - - -class Rule: - """A single set of rules for 'unasync'ing file(s)""" - - def __init__(self, fromdir, todir, additional_replacements=None): - self.fromdir = fromdir.replace("/", os.sep) - self.todir = todir.replace("/", os.sep) - - # Add any additional user-defined token replacements to our list. - self.token_replacements = _ASYNC_TO_SYNC.copy() - for key, val in (additional_replacements or {}).items(): - self.token_replacements[key] = val - - def _match(self, filepath): - """Determines if a Rule matches a given filepath and if so - returns a higher comparable value if the match is more specific. - """ - file_segments = [x for x in filepath.split(os.sep) if x] - from_segments = [x for x in self.fromdir.split(os.sep) if x] - len_from_segments = len(from_segments) - - if len_from_segments > len(file_segments): - return False - - for i in range(len(file_segments) - len_from_segments + 1): - if file_segments[i: i + len_from_segments] == from_segments: - return len_from_segments, i - - return False - - def _unasync_file(self, filepath): - with open(filepath, "rb") as f: - encoding, _ = std_tokenize.detect_encoding(f.readline) - - with open(filepath, "rt", encoding=encoding) as f: - tokens = tokenize_rt.src_to_tokens(f.read()) - tokens = self._unasync_tokens(tokens) - result = tokenize_rt.tokens_to_src(tokens) - outfilepath = filepath.replace(self.fromdir, self.todir) - os.makedirs(os.path.dirname(outfilepath), exist_ok=True) - with open(outfilepath, "wb") as f: - f.write(result.encode(encoding)) - - def _unasync_tokens(self, tokens: list): - new_tokens = [] - token_counter = 0 - while token_counter < len(tokens): - token = tokens[token_counter] - if token.src in ["async", "await"]: - # When removing async or await, we want to skip the following whitespace - token_counter = token_counter + 2 - continue - elif token.name == "NAME": - if token.src == "from": - if tokens[token_counter + 1].src == " ": - token_counter = self._replace_import(tokens, token_counter, new_tokens) - continue - else: - token = token._replace(src=self._unasync_name(token.src)) - elif token.name == "STRING": - src_token = token.src.replace("'", "") - if _STRING_REPLACE.get(src_token) is not None: - new_token = f"'{_STRING_REPLACE[src_token]}'" - token = token._replace(src=new_token) - - new_tokens.append(token) - token_counter = token_counter + 1 - - return new_tokens - - # for i, token in enumerate(tokens): - # if skip_next: - # skip_next = False - # continue - # - # if token.src in ["async", "await"]: - # # When removing async or await, we want to skip the following whitespace - # # so that `print(await stuff)` becomes `print(stuff)` and not `print( stuff)` - # skip_next = True - # else: - # if token.name == "NAME": - # token = token._replace(src=self._unasync_name(token.src)) - # elif token.name == "STRING": - # left_quote, name, right_quote = ( - # token.src[0], - # token.src[1:-1], - # token.src[-1], - # ) - # token = token._replace( - # src=left_quote + self._unasync_name(name) + right_quote - # ) - # - # yield token - - def _replace_import(self, tokens, token_counter, new_tokens: list): - new_tokens.append(tokens[token_counter]) - new_tokens.append(tokens[token_counter + 1]) - - full_lib_name = '' - lib_name_counter = token_counter + 2 - if len(_IMPORTS_REPLACE.keys()) == 0: - return lib_name_counter - - while True: - if tokens[lib_name_counter].src == " ": - break - full_lib_name = full_lib_name + tokens[lib_name_counter].src - lib_name_counter = lib_name_counter + 1 - - for key, value in _IMPORTS_REPLACE.items(): - if key in full_lib_name: - updated_lib_name = full_lib_name.replace(key, value) - for lib_name_part in updated_lib_name.split("."): - new_tokens.append(tokenize_rt.Token("NAME", lib_name_part)) - new_tokens.append(tokenize_rt.Token("OP", ".")) - new_tokens.pop() - return lib_name_counter - - lib_name_counter = token_counter + 2 - return lib_name_counter - - def _unasync_name(self, name): - if name in self.token_replacements: - return self.token_replacements[name] - # Convert classes prefixed with 'Async' into 'Sync' - # elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): - # return "Sync" + name[5:] - return name - - -def unasync_files(fpath_list, rules): - for f in fpath_list: - found_rule = None - found_weight = None - - for rule in rules: - weight = rule._match(f) - if weight and (found_weight is None or weight > found_weight): - found_rule = rule - found_weight = weight - - if found_rule: - found_rule._unasync_file(f) - - -_IMPORTS_REPLACE["ably"] = "ably.sync" -_IMPORTS_REPLACE["test.ably"] = "test.ably.sync" - -_STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' -_STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.Auth.request_token' -_STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' - -Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) - -src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") -dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - -os.makedirs(dest_dir_path, exist_ok=True) - - -def find_files(dir_path, file_name_regex) -> list[str]: - return glob.glob(os.path.join(dir_path, file_name_regex), recursive=True) - - -src_files = find_files(src_dir_path, "*.py") -unasync_files(src_files, (_DEFAULT_RULE,)) - -# round 2 -src_dir_path = os.path.join(os.getcwd(), "test", "ably") -dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - -src_files = [os.path.join(os.getcwd(), "test", "ably", "testapp.py"), - os.path.join(os.getcwd(), "test", "ably", "utils.py")] - -unasync_files(src_files, (_DEFAULT_RULE,)) From b7a95b8f1a3e21b20a0d2a3a506f419a31611a1c Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 16:55:29 +0530 Subject: [PATCH 22/52] Fixed test names warnings as per flake8 --- test/ably/rest/restauth_test.py | 4 ++-- test/ably/sync/rest/restauth_test.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/ably/rest/restauth_test.py b/test/ably/rest/restauth_test.py index a6ac0ceb..5e647920 100644 --- a/test/ably/rest/restauth_test.py +++ b/test/ably/rest/restauth_test.py @@ -346,7 +346,7 @@ async def test_with_key(self): @dont_vary_protocol @respx.mock - async def test_with_auth_url_headers_and_params_POST(self): # noqa: N802 + async def test_with_auth_url_headers_and_params_http_post(self): # noqa: N802 url = 'http://www.example.com' headers = {'foo': 'bar'} ably = await TestApp.get_ably_rest(key=None, auth_url=url) @@ -381,7 +381,7 @@ def call_back(request): @dont_vary_protocol @respx.mock - async def test_with_auth_url_headers_and_params_GET(self): # noqa: N802 + async def test_with_auth_url_headers_and_params_http_get(self): # noqa: N802 url = 'http://www.example.com' headers = {'foo': 'bar'} ably = await TestApp.get_ably_rest( diff --git a/test/ably/sync/rest/restauth_test.py b/test/ably/sync/rest/restauth_test.py index b2845390..660f1ae6 100644 --- a/test/ably/sync/rest/restauth_test.py +++ b/test/ably/sync/rest/restauth_test.py @@ -346,7 +346,7 @@ def test_with_key(self): @dont_vary_protocol @respx.mock - def test_with_auth_url_headers_and_params_POST(self): # noqa: N802 + def test_with_auth_url_headers_and_params_http_post(self): # noqa: N802 url = 'http://www.example.com' headers = {'foo': 'bar'} ably = TestApp.get_ably_rest(key=None, auth_url=url) @@ -381,7 +381,7 @@ def call_back(request): @dont_vary_protocol @respx.mock - def test_with_auth_url_headers_and_params_GET(self): # noqa: N802 + def test_with_auth_url_headers_and_params_http_get(self): # noqa: N802 url = 'http://www.example.com' headers = {'foo': 'bar'} ably = TestApp.get_ably_rest( From b37c5aa284b316fd1dee0fe1984ab6872d0ce75f Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 17:16:12 +0530 Subject: [PATCH 23/52] prefixed tests with different name to avoid pytest run issues --- .../sync/rest/{encoders_test.py => sync_encoders_test.py} | 0 .../sync/rest/{restauth_test.py => sync_restauth_test.py} | 0 ...restcapability_test.py => sync_restcapability_test.py} | 0 ...nelhistory_test.py => sync_restchannelhistory_test.py} | 0 ...nelpublish_test.py => sync_restchannelpublish_test.py} | 0 .../{restchannels_test.py => sync_restchannels_test.py} | 0 ...annelstatus_test.py => sync_restchannelstatus_test.py} | 0 .../rest/{restcrypto_test.py => sync_restcrypto_test.py} | 0 .../sync/rest/{resthttp_test.py => sync_resthttp_test.py} | 0 .../sync/rest/{restinit_test.py => sync_restinit_test.py} | 0 ...tedresult_test.py => sync_restpaginatedresult_test.py} | 0 .../{restpresence_test.py => sync_restpresence_test.py} | 0 .../sync/rest/{restpush_test.py => sync_restpush_test.py} | 0 .../{restrequest_test.py => sync_restrequest_test.py} | 0 .../rest/{reststats_test.py => sync_reststats_test.py} | 0 .../sync/rest/{resttime_test.py => sync_resttime_test.py} | 0 .../rest/{resttoken_test.py => sync_resttoken_test.py} | 0 unasync.py | 8 +++++--- 18 files changed, 5 insertions(+), 3 deletions(-) rename test/ably/sync/rest/{encoders_test.py => sync_encoders_test.py} (100%) rename test/ably/sync/rest/{restauth_test.py => sync_restauth_test.py} (100%) rename test/ably/sync/rest/{restcapability_test.py => sync_restcapability_test.py} (100%) rename test/ably/sync/rest/{restchannelhistory_test.py => sync_restchannelhistory_test.py} (100%) rename test/ably/sync/rest/{restchannelpublish_test.py => sync_restchannelpublish_test.py} (100%) rename test/ably/sync/rest/{restchannels_test.py => sync_restchannels_test.py} (100%) rename test/ably/sync/rest/{restchannelstatus_test.py => sync_restchannelstatus_test.py} (100%) rename test/ably/sync/rest/{restcrypto_test.py => sync_restcrypto_test.py} (100%) rename test/ably/sync/rest/{resthttp_test.py => sync_resthttp_test.py} (100%) rename test/ably/sync/rest/{restinit_test.py => sync_restinit_test.py} (100%) rename test/ably/sync/rest/{restpaginatedresult_test.py => sync_restpaginatedresult_test.py} (100%) rename test/ably/sync/rest/{restpresence_test.py => sync_restpresence_test.py} (100%) rename test/ably/sync/rest/{restpush_test.py => sync_restpush_test.py} (100%) rename test/ably/sync/rest/{restrequest_test.py => sync_restrequest_test.py} (100%) rename test/ably/sync/rest/{reststats_test.py => sync_reststats_test.py} (100%) rename test/ably/sync/rest/{resttime_test.py => sync_resttime_test.py} (100%) rename test/ably/sync/rest/{resttoken_test.py => sync_resttoken_test.py} (100%) diff --git a/test/ably/sync/rest/encoders_test.py b/test/ably/sync/rest/sync_encoders_test.py similarity index 100% rename from test/ably/sync/rest/encoders_test.py rename to test/ably/sync/rest/sync_encoders_test.py diff --git a/test/ably/sync/rest/restauth_test.py b/test/ably/sync/rest/sync_restauth_test.py similarity index 100% rename from test/ably/sync/rest/restauth_test.py rename to test/ably/sync/rest/sync_restauth_test.py diff --git a/test/ably/sync/rest/restcapability_test.py b/test/ably/sync/rest/sync_restcapability_test.py similarity index 100% rename from test/ably/sync/rest/restcapability_test.py rename to test/ably/sync/rest/sync_restcapability_test.py diff --git a/test/ably/sync/rest/restchannelhistory_test.py b/test/ably/sync/rest/sync_restchannelhistory_test.py similarity index 100% rename from test/ably/sync/rest/restchannelhistory_test.py rename to test/ably/sync/rest/sync_restchannelhistory_test.py diff --git a/test/ably/sync/rest/restchannelpublish_test.py b/test/ably/sync/rest/sync_restchannelpublish_test.py similarity index 100% rename from test/ably/sync/rest/restchannelpublish_test.py rename to test/ably/sync/rest/sync_restchannelpublish_test.py diff --git a/test/ably/sync/rest/restchannels_test.py b/test/ably/sync/rest/sync_restchannels_test.py similarity index 100% rename from test/ably/sync/rest/restchannels_test.py rename to test/ably/sync/rest/sync_restchannels_test.py diff --git a/test/ably/sync/rest/restchannelstatus_test.py b/test/ably/sync/rest/sync_restchannelstatus_test.py similarity index 100% rename from test/ably/sync/rest/restchannelstatus_test.py rename to test/ably/sync/rest/sync_restchannelstatus_test.py diff --git a/test/ably/sync/rest/restcrypto_test.py b/test/ably/sync/rest/sync_restcrypto_test.py similarity index 100% rename from test/ably/sync/rest/restcrypto_test.py rename to test/ably/sync/rest/sync_restcrypto_test.py diff --git a/test/ably/sync/rest/resthttp_test.py b/test/ably/sync/rest/sync_resthttp_test.py similarity index 100% rename from test/ably/sync/rest/resthttp_test.py rename to test/ably/sync/rest/sync_resthttp_test.py diff --git a/test/ably/sync/rest/restinit_test.py b/test/ably/sync/rest/sync_restinit_test.py similarity index 100% rename from test/ably/sync/rest/restinit_test.py rename to test/ably/sync/rest/sync_restinit_test.py diff --git a/test/ably/sync/rest/restpaginatedresult_test.py b/test/ably/sync/rest/sync_restpaginatedresult_test.py similarity index 100% rename from test/ably/sync/rest/restpaginatedresult_test.py rename to test/ably/sync/rest/sync_restpaginatedresult_test.py diff --git a/test/ably/sync/rest/restpresence_test.py b/test/ably/sync/rest/sync_restpresence_test.py similarity index 100% rename from test/ably/sync/rest/restpresence_test.py rename to test/ably/sync/rest/sync_restpresence_test.py diff --git a/test/ably/sync/rest/restpush_test.py b/test/ably/sync/rest/sync_restpush_test.py similarity index 100% rename from test/ably/sync/rest/restpush_test.py rename to test/ably/sync/rest/sync_restpush_test.py diff --git a/test/ably/sync/rest/restrequest_test.py b/test/ably/sync/rest/sync_restrequest_test.py similarity index 100% rename from test/ably/sync/rest/restrequest_test.py rename to test/ably/sync/rest/sync_restrequest_test.py diff --git a/test/ably/sync/rest/reststats_test.py b/test/ably/sync/rest/sync_reststats_test.py similarity index 100% rename from test/ably/sync/rest/reststats_test.py rename to test/ably/sync/rest/sync_reststats_test.py diff --git a/test/ably/sync/rest/resttime_test.py b/test/ably/sync/rest/sync_resttime_test.py similarity index 100% rename from test/ably/sync/rest/resttime_test.py rename to test/ably/sync/rest/sync_resttime_test.py diff --git a/test/ably/sync/rest/resttoken_test.py b/test/ably/sync/rest/sync_resttoken_test.py similarity index 100% rename from test/ably/sync/rest/resttoken_test.py rename to test/ably/sync/rest/sync_resttoken_test.py diff --git a/unasync.py b/unasync.py index aa55a84b..213f338a 100644 --- a/unasync.py +++ b/unasync.py @@ -36,9 +36,10 @@ class Rule: """A single set of rules for 'unasync'ing file(s)""" - def __init__(self, fromdir, todir, additional_replacements=None): + def __init__(self, fromdir, todir, output_file_prefix="", additional_replacements=None): self.fromdir = fromdir.replace("/", os.sep) self.todir = todir.replace("/", os.sep) + self.ouput_file_prefix = output_file_prefix # Add any additional user-defined token replacements to our list. self.token_replacements = _ASYNC_TO_SYNC.copy() @@ -70,7 +71,8 @@ def _unasync_file(self, filepath): tokens = tokenize_rt.src_to_tokens(f.read()) tokens = self._unasync_tokens(tokens) result = tokenize_rt.tokens_to_src(tokens) - outfilepath = filepath.replace(self.fromdir, self.todir) + new_file_path = os.path.join(os.path.dirname(filepath), self.ouput_file_prefix + os.path.basename(filepath)) + outfilepath = new_file_path.replace(self.fromdir, self.todir) os.makedirs(os.path.dirname(outfilepath), exist_ok=True) with open(outfilepath, "wb") as f: f.write(result.encode(encoding)) @@ -226,7 +228,7 @@ def find_files(dir_path, file_name_regex) -> list[str]: src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) +_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path, output_file_prefix="sync_") os.makedirs(dest_dir_path, exist_ok=True) From 88013a935d2b1f22959fb30f663b12156ab08c78 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 17:17:54 +0530 Subject: [PATCH 24/52] Fixed flake8 issues for unasync file --- unasync.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unasync.py b/unasync.py index 213f338a..5ab64490 100644 --- a/unasync.py +++ b/unasync.py @@ -71,7 +71,8 @@ def _unasync_file(self, filepath): tokens = tokenize_rt.src_to_tokens(f.read()) tokens = self._unasync_tokens(tokens) result = tokenize_rt.tokens_to_src(tokens) - new_file_path = os.path.join(os.path.dirname(filepath), self.ouput_file_prefix + os.path.basename(filepath)) + new_file_path = os.path.join(os.path.dirname(filepath), + self.ouput_file_prefix + os.path.basename(filepath)) outfilepath = new_file_path.replace(self.fromdir, self.todir) os.makedirs(os.path.dirname(outfilepath), exist_ok=True) with open(outfilepath, "wb") as f: From d91c171838a08f6649bf5760b6d96421b756fd58 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 22:19:03 +0530 Subject: [PATCH 25/52] Added missing string replacements to unasync generator --- unasync.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unasync.py b/unasync.py index 5ab64490..6da7f8a6 100644 --- a/unasync.py +++ b/unasync.py @@ -224,6 +224,8 @@ def find_files(dir_path, file_name_regex) -> list[str]: _STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' _STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.Auth.request_token' _STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' +_STRING_REPLACE['ably.rest.rest.Http.post'] = 'ably.sync.rest.rest.Http.post' +_STRING_REPLACE['httpx.AsyncClient.send'] = 'httpx.Client.send' Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) From d9bbe93b43990476d8cfd63947c562e483ec8fb8 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 22:21:23 +0530 Subject: [PATCH 26/52] Added more generic way to find submodules directory --- test/ably/rest/restchannelpublish_test.py | 5 ++--- test/ably/utils.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/ably/rest/restchannelpublish_test.py b/test/ably/rest/restchannelpublish_test.py index 6cf458eb..b38d286b 100644 --- a/test/ably/rest/restchannelpublish_test.py +++ b/test/ably/rest/restchannelpublish_test.py @@ -18,7 +18,7 @@ from ably.util import case from test.ably.testapp import TestApp -from test.ably.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase +from test.ably.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase, get_submodule_dir log = logging.getLogger(__name__) @@ -385,8 +385,7 @@ async def test_interoperability(self): 'binary': bytearray, } - root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - path = os.path.join(root_dir, 'submodules', 'test-resources', 'messages-encoding.json') + path = os.path.join(get_submodule_dir(__file__), 'submodules', 'test-resources', 'messages-encoding.json') with open(path) as f: data = json.load(f) for input_msg in data['messages']: diff --git a/test/ably/utils.py b/test/ably/utils.py index cb0a5b0d..0edddb90 100644 --- a/test/ably/utils.py +++ b/test/ably/utils.py @@ -1,8 +1,10 @@ import functools +import os import random import string import unittest import sys + if sys.version_info >= (3, 8): from unittest import IsolatedAsyncioTestCase else: @@ -90,8 +92,8 @@ async def test_decorated(self, *args, **kwargs): await fn(self, *args, **kwargs) unpatch(patcher) - assert len(responses) >= 1,\ - "If your test doesn't make any requests, use the @dont_vary_protocol decorator" + assert len(responses) >= 1, \ + "If your test doesn't make any requests, use the @dont_vary_protocol decorator" for response in responses: # In HTTP/2 some header fields are optional in case of 204 status code @@ -107,6 +109,7 @@ async def test_decorated(self, *args, **kwargs): msgpack.unpackb(response.content) return test_decorated + return test_decorator @@ -122,11 +125,11 @@ def per_protocol_setup(self, use_binary_protocol): is called * exclude tests with the @dont_vary_protocol decorator """ + def __new__(cls, clsname, bases, dct): for key, value in tuple(dct.items()): if key.startswith('test') and not getattr(value, 'dont_vary_protocol', False): - wrapper_bin = cls.wrap_as('bin', key, value) wrapper_text = cls.wrap_as('text', key, value) @@ -145,6 +148,7 @@ async def wrapper(self): if hasattr(self, 'per_protocol_setup'): self.per_protocol_setup(ttype == 'bin') await old_func(self) + wrapper.__name__ = old_name + '_' + ttype return wrapper @@ -166,3 +170,11 @@ def new_dict(src, **kw): def get_random_key(d): return random.choice(list(d)) + + +def get_submodule_dir(filepath): + root_dir = os.path.dirname(filepath) + while True: + if os.path.exists(os.path.join(root_dir, 'submodules')): + return os.path.join(root_dir, 'submodules') + root_dir = os.path.dirname(root_dir) From 8fe44b5504051d3623de8386a3abd1373d0cb4ed Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 22:35:50 +0530 Subject: [PATCH 27/52] Refactored unasync, added more string replacements to fix tests --- test/ably/rest/restchannelpublish_test.py | 2 +- unasync.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/test/ably/rest/restchannelpublish_test.py b/test/ably/rest/restchannelpublish_test.py index b38d286b..9a51a76d 100644 --- a/test/ably/rest/restchannelpublish_test.py +++ b/test/ably/rest/restchannelpublish_test.py @@ -385,7 +385,7 @@ async def test_interoperability(self): 'binary': bytearray, } - path = os.path.join(get_submodule_dir(__file__), 'submodules', 'test-resources', 'messages-encoding.json') + path = os.path.join(get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') with open(path) as f: data = json.load(f) for input_msg in data['messages']: diff --git a/unasync.py b/unasync.py index 6da7f8a6..3dc866b0 100644 --- a/unasync.py +++ b/unasync.py @@ -226,8 +226,11 @@ def find_files(dir_path, file_name_regex) -> list[str]: _STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' _STRING_REPLACE['ably.rest.rest.Http.post'] = 'ably.sync.rest.rest.Http.post' _STRING_REPLACE['httpx.AsyncClient.send'] = 'httpx.Client.send' +_STRING_REPLACE['ably.util.exceptions.AblyException.raise_for_response'] = \ + 'ably.sync.util.exceptions.AblyException.raise_for_response' +_STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRest.time' +_STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.Auth._timestamp' -Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") @@ -235,11 +238,6 @@ def find_files(dir_path, file_name_regex) -> list[str]: os.makedirs(dest_dir_path, exist_ok=True) - -def find_files(dir_path, file_name_regex) -> list[str]: - return glob.glob(os.path.join(dir_path, file_name_regex), recursive=True) - - src_files = find_files(src_dir_path, "*.py") unasync_files(src_files, (_DEFAULT_RULE,)) From 57f1c287ae3f2fd89aa854a5f4381eb8823c89cb Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 22:40:22 +0530 Subject: [PATCH 28/52] Regenerated sync tests --- test/ably/rest/resttoken_test.py | 2 +- test/ably/sync/rest/sync_encoders_test.py | 38 +++++++++---------- .../sync/rest/sync_restchannelpublish_test.py | 13 +++---- test/ably/sync/rest/sync_resthttp_test.py | 10 ++--- test/ably/sync/rest/sync_restinit_test.py | 4 +- test/ably/sync/rest/sync_resttoken_test.py | 20 +++++----- test/ably/sync/utils.py | 18 +++++++-- 7 files changed, 58 insertions(+), 47 deletions(-) diff --git a/test/ably/rest/resttoken_test.py b/test/ably/rest/resttoken_test.py index a50c5ea4..7610868d 100644 --- a/test/ably/rest/resttoken_test.py +++ b/test/ably/rest/resttoken_test.py @@ -40,7 +40,7 @@ async def test_request_token_null_params(self): post_time = await self.server_time() assert token_details.token is not None, "Expected token" assert token_details.issued + 300 >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time, "Unexpected issued time" + assert token_details.issued <= post_time + 300, "Unexpected issued time" assert self.permit_all == str(token_details.capability), "Unexpected capability" async def test_request_token_explicit_timestamp(self): diff --git a/test/ably/sync/rest/sync_encoders_test.py b/test/ably/sync/rest/sync_encoders_test.py index 83d2e852..8fde66b4 100644 --- a/test/ably/sync/rest/sync_encoders_test.py +++ b/test/ably/sync/rest/sync_encoders_test.py @@ -31,7 +31,7 @@ def tearDown(self): def test_text_utf8(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', 'foó') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['data'] == 'foó' @@ -41,7 +41,7 @@ def test_str(self): # This test only makes sense for py2 channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', 'foo') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['data'] == 'foo' @@ -50,7 +50,7 @@ def test_str(self): def test_with_binary_type(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args raw_data = json.loads(kwargs['body'])['data'] @@ -60,7 +60,7 @@ def test_with_binary_type(self): def test_with_bytes_type(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', b'foo') _, kwargs = post_mock.call_args raw_data = json.loads(kwargs['body'])['data'] @@ -70,7 +70,7 @@ def test_with_bytes_type(self): def test_with_json_dict_data(self): channel = self.ably.channels["persisted:publish"] data = {'foó': 'bár'} - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args raw_data = json.loads(json.loads(kwargs['body'])['data']) @@ -80,7 +80,7 @@ def test_with_json_dict_data(self): def test_with_json_list_data(self): channel = self.ably.channels["persisted:publish"] data = ['foó', 'bár'] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args raw_data = json.loads(json.loads(kwargs['body'])['data']) @@ -162,7 +162,7 @@ def decrypt(self, payload, options=None): def test_text_utf8(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', 'fóo') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['encoding'].strip('/') == 'utf-8/cipher+aes-128-cbc/base64' @@ -173,7 +173,7 @@ def test_str(self): # This test only makes sense for py2 channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', 'foo') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['data'] == 'foo' @@ -183,7 +183,7 @@ def test_with_binary_type(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args @@ -196,7 +196,7 @@ def test_with_json_dict_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = {'foó': 'bár'} - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' @@ -207,7 +207,7 @@ def test_with_json_list_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = ['foó', 'bár'] - with mock.patch('ably.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' @@ -270,7 +270,7 @@ def decode(self, data): def test_text_utf8(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', 'foó') _, kwargs = post_mock.call_args @@ -280,7 +280,7 @@ def test_text_utf8(self): def test_with_binary_type(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args @@ -290,7 +290,7 @@ def test_with_binary_type(self): def test_with_json_dict_data(self): channel = self.ably.channels["persisted:publish"] data = {'foó': 'bár'} - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args @@ -301,7 +301,7 @@ def test_with_json_dict_data(self): def test_with_json_list_data(self): channel = self.ably.channels["persisted:publish"] data = ['foó', 'bár'] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args @@ -368,7 +368,7 @@ def decode(self, data): def test_text_utf8(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', 'fóo') _, kwargs = post_mock.call_args @@ -380,7 +380,7 @@ def test_with_binary_type(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args @@ -394,7 +394,7 @@ def test_with_json_dict_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = {'foó': 'bár'} - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args @@ -406,7 +406,7 @@ def test_with_json_list_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = ['foó', 'bár'] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args diff --git a/test/ably/sync/rest/sync_restchannelpublish_test.py b/test/ably/sync/rest/sync_restchannelpublish_test.py index 38bfb1b9..07dbcba5 100644 --- a/test/ably/sync/rest/sync_restchannelpublish_test.py +++ b/test/ably/sync/rest/sync_restchannelpublish_test.py @@ -18,7 +18,7 @@ from ably.sync.util import case from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase, get_submodule_dir log = logging.getLogger(__name__) @@ -104,7 +104,7 @@ def test_message_list_generate_one_request(self): expected_messages = [Message("name-{}".format(i), str(i)) for i in range(3)] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish(messages=expected_messages) assert post_mock.call_count == 1 @@ -185,7 +185,7 @@ def test_publish_message_null_name_and_data_keys_arent_sent(self): channel = self.ably.channels[ self.get_channel_name('persisted:null_name_and_data_keys_arent_sent_channel')] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish(name=None, data=None) @@ -245,7 +245,7 @@ def test_publish_message_without_client_id_on_identified_client(self): channel = self.ably_with_client_id.channels[ self.get_channel_name('persisted:no_client_id_identified_client')] - with mock.patch('ably.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.Http.post', wraps=channel.ably.http.post) as post_mock: channel.publish(name='publish', data='test') @@ -385,8 +385,7 @@ def test_interoperability(self): 'binary': bytearray, } - root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - path = os.path.join(root_dir, 'submodules', 'test-resources', 'messages-encoding.json') + path = os.path.join(get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') with open(path) as f: data = json.load(f) for input_msg in data['messages']: @@ -545,7 +544,7 @@ def side_effect(*args, **kwargs): return x messages = [Message('name1', 'data1')] - with mock.patch('httpx.AsyncClient.send', side_effect=side_effect, autospec=True): + with mock.patch('httpx.Client.send', side_effect=side_effect, autospec=True): channel.publish(messages=messages) assert state['failures'] == 2 diff --git a/test/ably/sync/rest/sync_resthttp_test.py b/test/ably/sync/rest/sync_resthttp_test.py index 8b8fe771..372916ea 100644 --- a/test/ably/sync/rest/sync_resthttp_test.py +++ b/test/ably/sync/rest/sync_resthttp_test.py @@ -24,7 +24,7 @@ def test_max_retry_attempts_and_timeouts_defaults(self): assert 'http_open_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS assert 'http_request_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS - with mock.patch('httpx.AsyncClient.send', side_effect=httpx.RequestError('')) as send_mock: + with mock.patch('httpx.Client.send', side_effect=httpx.RequestError('')) as send_mock: with pytest.raises(httpx.RequestError): ably.http.make_request('GET', '/', version=Defaults.protocol_version, skip_auth=True) @@ -42,7 +42,7 @@ def sleep_and_raise(*args, **kwargs): time.sleep(0.51) raise httpx.TimeoutException('timeout') - with mock.patch('httpx.AsyncClient.send', side_effect=sleep_and_raise) as send_mock: + with mock.patch('httpx.Client.send', side_effect=sleep_and_raise) as send_mock: with pytest.raises(httpx.TimeoutException): ably.http.make_request('GET', '/', skip_auth=True) @@ -59,7 +59,7 @@ def make_url(host): return urljoin(base_url, '/') with mock.patch('httpx.Request', wraps=httpx.Request) as request_mock: - with mock.patch('httpx.AsyncClient.send', side_effect=httpx.RequestError('')) as send_mock: + with mock.patch('httpx.Client.send', side_effect=httpx.RequestError('')) as send_mock: with pytest.raises(httpx.RequestError): ably.http.make_request('GET', '/', skip_auth=True) @@ -110,7 +110,7 @@ def side_effect(*args, **kwargs): raise RuntimeError return send(args[1]) - with mock.patch('httpx.AsyncClient.send', side_effect=side_effect, autospec=True): + with mock.patch('httpx.Client.send', side_effect=side_effect, autospec=True): # The main host is called and there's an error ably.time() assert state['errors'] == 1 @@ -163,7 +163,7 @@ def raise_ably_exception(*args, **kwargs): raise AblyException(message="", status_code=500, code=50000) with mock.patch('httpx.Request', wraps=httpx.Request): - with mock.patch('ably.util.exceptions.AblyException.raise_for_response', + with mock.patch('ably.sync.util.exceptions.AblyException.raise_for_response', side_effect=raise_ably_exception) as send_mock: with pytest.raises(AblyException): ably.http.make_request('GET', '/', skip_auth=True) diff --git a/test/ably/sync/rest/sync_restinit_test.py b/test/ably/sync/rest/sync_restinit_test.py index 8a6864ad..3b50b4b0 100644 --- a/test/ably/sync/rest/sync_restinit_test.py +++ b/test/ably/sync/rest/sync_restinit_test.py @@ -168,8 +168,8 @@ def test_query_time_param(self): use_binary_protocol=self.use_binary_protocol) timestamp = ably.auth._timestamp - with patch('ably.rest.rest.AblyRest.time', wraps=ably.time) as server_time,\ - patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRest.time', wraps=ably.time) as server_time,\ + patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: ably.auth.request_token() assert local_time.call_count == 1 assert server_time.call_count == 1 diff --git a/test/ably/sync/rest/sync_resttoken_test.py b/test/ably/sync/rest/sync_resttoken_test.py index d31e9441..03e1c480 100644 --- a/test/ably/sync/rest/sync_resttoken_test.py +++ b/test/ably/sync/rest/sync_resttoken_test.py @@ -40,7 +40,7 @@ def test_request_token_null_params(self): post_time = self.server_time() assert token_details.token is not None, "Expected token" assert token_details.issued + 300 >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time, "Unexpected issued time" + assert token_details.issued <= post_time + 300, "Unexpected issued time" assert self.permit_all == str(token_details.capability), "Unexpected capability" def test_request_token_explicit_timestamp(self): @@ -123,8 +123,8 @@ def test_token_generation_with_invalid_ttl(self): def test_token_generation_with_local_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: self.ably.auth.request_token() assert local_time.called assert not server_time.called @@ -132,8 +132,8 @@ def test_token_generation_with_local_time(self): # RSA10k def test_token_generation_with_server_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: self.ably.auth.request_token(query_time=True) assert local_time.call_count == 1 assert server_time.call_count == 1 @@ -185,8 +185,8 @@ def test_key_name_and_secret_are_required(self): @dont_vary_protocol def test_with_local_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: self.ably.auth.create_token_request( key_name=self.key_name, key_secret=self.key_secret, query_time=False) assert local_time.called @@ -196,8 +196,8 @@ def test_with_local_time(self): @dont_vary_protocol def test_with_server_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: self.ably.auth.create_token_request( key_name=self.key_name, key_secret=self.key_secret, query_time=True) assert local_time.call_count == 1 @@ -332,7 +332,7 @@ def test_hmac(self): # AO2g @dont_vary_protocol def test_query_server_time(self): - with patch('ably.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time: + with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time: self.ably.auth.create_token_request( key_name=self.key_name, key_secret=self.key_secret, query_time=True) assert server_time.call_count == 1 diff --git a/test/ably/sync/utils.py b/test/ably/sync/utils.py index c3d68f79..7bc4ebd7 100644 --- a/test/ably/sync/utils.py +++ b/test/ably/sync/utils.py @@ -1,8 +1,10 @@ import functools +import os import random import string import unittest import sys + if sys.version_info >= (3, 8): from unittest import IsolatedAsyncioTestCase else: @@ -90,8 +92,8 @@ def test_decorated(self, *args, **kwargs): fn(self, *args, **kwargs) unpatch(patcher) - assert len(responses) >= 1,\ - "If your test doesn't make any requests, use the @dont_vary_protocol decorator" + assert len(responses) >= 1, \ + "If your test doesn't make any requests, use the @dont_vary_protocol decorator" for response in responses: # In HTTP/2 some header fields are optional in case of 204 status code @@ -107,6 +109,7 @@ def test_decorated(self, *args, **kwargs): msgpack.unpackb(response.content) return test_decorated + return test_decorator @@ -122,11 +125,11 @@ def per_protocol_setup(self, use_binary_protocol): is called * exclude tests with the @dont_vary_protocol decorator """ + def __new__(cls, clsname, bases, dct): for key, value in tuple(dct.items()): if key.startswith('test') and not getattr(value, 'dont_vary_protocol', False): - wrapper_bin = cls.wrap_as('bin', key, value) wrapper_text = cls.wrap_as('text', key, value) @@ -145,6 +148,7 @@ def wrapper(self): if hasattr(self, 'per_protocol_setup'): self.per_protocol_setup(ttype == 'bin') old_func(self) + wrapper.__name__ = old_name + '_' + ttype return wrapper @@ -166,3 +170,11 @@ def new_dict(src, **kw): def get_random_key(d): return random.choice(list(d)) + + +def get_submodule_dir(filepath): + root_dir = os.path.dirname(filepath) + while True: + if os.path.exists(os.path.join(root_dir, 'submodules')): + return os.path.join(root_dir, 'submodules') + root_dir = os.path.dirname(root_dir) From ee6bc6448cbed77e403ff14589da9e0d5cd9a8d5 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 22:43:29 +0530 Subject: [PATCH 29/52] Fixed linting issue for restchannelpublish --- test/ably/rest/restchannelpublish_test.py | 5 +++-- test/ably/sync/rest/sync_restchannelpublish_test.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/ably/rest/restchannelpublish_test.py b/test/ably/rest/restchannelpublish_test.py index 9a51a76d..882bedc4 100644 --- a/test/ably/rest/restchannelpublish_test.py +++ b/test/ably/rest/restchannelpublish_test.py @@ -16,9 +16,10 @@ from ably.types.message import Message from ably.types.tokendetails import TokenDetails from ably.util import case +from test.ably import utils from test.ably.testapp import TestApp -from test.ably.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase, get_submodule_dir +from test.ably.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase log = logging.getLogger(__name__) @@ -385,7 +386,7 @@ async def test_interoperability(self): 'binary': bytearray, } - path = os.path.join(get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') + path = os.path.join(utils.get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') with open(path) as f: data = json.load(f) for input_msg in data['messages']: diff --git a/test/ably/sync/rest/sync_restchannelpublish_test.py b/test/ably/sync/rest/sync_restchannelpublish_test.py index 07dbcba5..582dc94b 100644 --- a/test/ably/sync/rest/sync_restchannelpublish_test.py +++ b/test/ably/sync/rest/sync_restchannelpublish_test.py @@ -16,9 +16,10 @@ from ably.sync.types.message import Message from ably.sync.types.tokendetails import TokenDetails from ably.sync.util import case +from test.ably.sync import utils from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase, get_submodule_dir +from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase log = logging.getLogger(__name__) @@ -385,7 +386,7 @@ def test_interoperability(self): 'binary': bytearray, } - path = os.path.join(get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') + path = os.path.join(utils.get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') with open(path) as f: data = json.load(f) for input_msg in data['messages']: From 6160dedabe7fb3605f0de9fb49eeea29a8c269d2 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Thu, 5 Oct 2023 22:55:05 +0530 Subject: [PATCH 30/52] Refactored unasync code, removed unnecessary garbage --- unasync.py | 43 ++++++++++++++++--------------------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/unasync.py b/unasync.py index 3dc866b0..a3d5f115 100644 --- a/unasync.py +++ b/unasync.py @@ -1,6 +1,5 @@ """Top-level package for unasync.""" -import collections import glob import os import tokenize as std_tokenize @@ -187,30 +186,25 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) -_IMPORTS_REPLACE["ably"] = "ably.sync" +def find_files(dir_path, file_name_regex) -> list[str]: + return glob.glob(os.path.join(dir_path, "**", file_name_regex), recursive=True) -Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) # Source files ========================================== -src_dir_path = os.path.join(os.getcwd(), "ably") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - -os.makedirs(dest_dir_path, exist_ok=True) +_IMPORTS_REPLACE["ably"] = "ably.sync" -def find_files(dir_path, file_name_regex) -> list[str]: - return glob.glob(os.path.join(dir_path, "**", file_name_regex), recursive=True) - +src_dir_path = os.path.join(os.getcwd(), "ably") +dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") relevant_src_files = (set(find_files(src_dir_path, "*.py")) - set(find_files(dest_dir_path, "*.py"))) -unasync_files(list(relevant_src_files), (_DEFAULT_RULE,)) +unasync_files(list(relevant_src_files), [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) -# Test files ============================================== +# Test files ============================================== _ASYNC_TO_SYNC["AsyncClient"] = "Client" _ASYNC_TO_SYNC["aclose"] = "close" @@ -231,22 +225,17 @@ def find_files(dir_path, file_name_regex) -> list[str]: _STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRest.time' _STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.Auth._timestamp' - -src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") -dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path, output_file_prefix="sync_") - -os.makedirs(dest_dir_path, exist_ok=True) - -src_files = find_files(src_dir_path, "*.py") -unasync_files(src_files, (_DEFAULT_RULE,)) - -# round 2 +# round 1 src_dir_path = os.path.join(os.getcwd(), "test", "ably") dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") -_DEFAULT_RULE = Rule(fromdir=src_dir_path, todir=dest_dir_path) - src_files = [os.path.join(os.getcwd(), "test", "ably", "testapp.py"), os.path.join(os.getcwd(), "test", "ably", "utils.py")] -unasync_files(src_files, (_DEFAULT_RULE,)) +unasync_files(src_files, [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) + +# round 2 +src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") +dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") +src_files = find_files(src_dir_path, "*.py") + +unasync_files(src_files, [Rule(fromdir=src_dir_path, todir=dest_dir_path, output_file_prefix="sync_")]) From 4a832733f79868086f3a9efcedb948e6b901ff14 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 15:54:43 +0530 Subject: [PATCH 31/52] Refactored unasync.py, added feature to rename classes, updated tests for the same --- unasync.py | 60 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/unasync.py b/unasync.py index a3d5f115..302fa55c 100644 --- a/unasync.py +++ b/unasync.py @@ -1,12 +1,10 @@ -"""Top-level package for unasync.""" - import glob import os import tokenize as std_tokenize import tokenize_rt -_ASYNC_TO_SYNC = { +_TOKEN_REPLACE = { "__aenter__": "__enter__", "__aexit__": "__exit__", "__aiter__": "__iter__", @@ -15,22 +13,18 @@ "AsyncIterable": "Iterable", "AsyncIterator": "Iterator", "AsyncGenerator": "Generator", - # TODO StopIteration is still accepted in Python 2, but the right change - # is 'raise StopAsyncIteration' -> 'return' since we want to use unasynced - # code in Python 3.7+ "StopAsyncIteration": "StopIteration", - "AsyncClient": "Client", - "aclose": "close" } _IMPORTS_REPLACE = { - } - _STRING_REPLACE = { } +_CLASS_RENAME = { +} + class Rule: """A single set of rules for 'unasync'ing file(s)""" @@ -41,7 +35,7 @@ def __init__(self, fromdir, todir, output_file_prefix="", additional_replacement self.ouput_file_prefix = output_file_prefix # Add any additional user-defined token replacements to our list. - self.token_replacements = _ASYNC_TO_SYNC.copy() + self.token_replacements = _TOKEN_REPLACE.copy() for key, val in (additional_replacements or {}).items(): self.token_replacements[key] = val @@ -132,6 +126,11 @@ def _unasync_tokens(self, tokens: list): if _STRING_REPLACE.get(src_token) is not None: new_token = f"'{_STRING_REPLACE[src_token]}'" token = token._replace(src=new_token) + else: + src_token = token.src.replace("\"", "") + if _STRING_REPLACE.get(src_token) is not None: + new_token = f"\"{_STRING_REPLACE[src_token]}\"" + token = token._replace(src=new_token) new_tokens.append(token) token_counter = token_counter + 1 @@ -157,6 +156,7 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): if key in full_lib_name: updated_lib_name = full_lib_name.replace(key, value) for lib_name_part in updated_lib_name.split("."): + lib_name_part = self._unasync_name(lib_name_part) new_tokens.append(tokenize_rt.Token("NAME", lib_name_part)) new_tokens.append(tokenize_rt.Token("OP", ".")) new_tokens.pop() @@ -168,6 +168,8 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): def _unasync_name(self, name): if name in self.token_replacements: return self.token_replacements[name] + if name in _CLASS_RENAME: + return _CLASS_RENAME[name] return name @@ -192,9 +194,23 @@ def find_files(dir_path, file_name_regex) -> list[str]: # Source files ========================================== +_TOKEN_REPLACE["AsyncClient"] = "Client" +_TOKEN_REPLACE["aclose"] = "close" _IMPORTS_REPLACE["ably"] = "ably.sync" +_CLASS_RENAME["AblyRest"] = "AblyRestSync" +_CLASS_RENAME["Push"] = "PushSync" +_CLASS_RENAME["PushAdmin"] = "PushAdminSync" +_CLASS_RENAME["Channel"] = "ChannelSync" +_CLASS_RENAME["Channels"] = "ChannelsSync" +_CLASS_RENAME["Auth"] = "AuthSync" +_CLASS_RENAME["Http"] = "HttpSync" +_CLASS_RENAME["PaginatedResult"] = "PaginatedResultSync" +_CLASS_RENAME["HttpPaginatedResponse"] = "HttpPaginatedResponseSync" + +_STRING_REPLACE["Auth"] = "AuthSync" + src_dir_path = os.path.join(os.getcwd(), "ably") dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") @@ -203,27 +219,27 @@ def find_files(dir_path, file_name_regex) -> list[str]: unasync_files(list(relevant_src_files), [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) - # Test files ============================================== -_ASYNC_TO_SYNC["AsyncClient"] = "Client" -_ASYNC_TO_SYNC["aclose"] = "close" -_ASYNC_TO_SYNC["asyncSetUp"] = "setUp" -_ASYNC_TO_SYNC["asyncTearDown"] = "tearDown" -_ASYNC_TO_SYNC["AsyncMock"] = "Mock" +_TOKEN_REPLACE["asyncSetUp"] = "setUp" +_TOKEN_REPLACE["asyncTearDown"] = "tearDown" +_TOKEN_REPLACE["AsyncMock"] = "Mock" + +_TOKEN_REPLACE["_Channel__publish_request_body"] = "_ChannelSync__publish_request_body" +_TOKEN_REPLACE["_Http__client"] = "_HttpSync__client" -_IMPORTS_REPLACE["ably"] = "ably.sync" _IMPORTS_REPLACE["test.ably"] = "test.ably.sync" _STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' -_STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.Auth.request_token' +_STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.AuthSync.request_token' _STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' -_STRING_REPLACE['ably.rest.rest.Http.post'] = 'ably.sync.rest.rest.Http.post' +_STRING_REPLACE['ably.rest.rest.Http.post'] = 'ably.sync.rest.rest.HttpSync.post' _STRING_REPLACE['httpx.AsyncClient.send'] = 'httpx.Client.send' _STRING_REPLACE['ably.util.exceptions.AblyException.raise_for_response'] = \ 'ably.sync.util.exceptions.AblyException.raise_for_response' -_STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRest.time' -_STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.Auth._timestamp' +_STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRestSync.time' +_STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.AuthSync._timestamp' + # round 1 src_dir_path = os.path.join(os.getcwd(), "test", "ably") From 732d04853987273d63c3283b39b88aaafcceb1b2 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 15:55:44 +0530 Subject: [PATCH 32/52] Updated resttoken test to support assertion in sync tests --- test/ably/rest/resttoken_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ably/rest/resttoken_test.py b/test/ably/rest/resttoken_test.py index 7610868d..9e74e695 100644 --- a/test/ably/rest/resttoken_test.py +++ b/test/ably/rest/resttoken_test.py @@ -40,7 +40,7 @@ async def test_request_token_null_params(self): post_time = await self.server_time() assert token_details.token is not None, "Expected token" assert token_details.issued + 300 >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time + 300, "Unexpected issued time" + assert token_details.issued <= post_time + 500, "Unexpected issued time" assert self.permit_all == str(token_details.capability), "Unexpected capability" async def test_request_token_explicit_timestamp(self): From 4c468a875bbbb37eba9421633efe54f2184c80ad Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 15:57:07 +0530 Subject: [PATCH 33/52] Renamed relevant public classes for sync support --- ably/sync/__init__.py | 6 +- ably/sync/http/http.py | 6 +- ably/sync/http/paginatedresult.py | 4 +- ably/sync/realtime/realtime.py | 10 +-- ably/sync/realtime/realtime_channel.py | 8 +-- ably/sync/rest/auth.py | 20 +++--- ably/sync/rest/channel.py | 12 ++-- ably/sync/rest/push.py | 14 ++-- ably/sync/rest/rest.py | 24 +++---- ably/sync/types/presence.py | 6 +- test/ably/sync/rest/sync_encoders_test.py | 38 +++++----- test/ably/sync/rest/sync_restauth_test.py | 70 +++++++++--------- .../sync/rest/sync_restchannelhistory_test.py | 4 +- .../sync/rest/sync_restchannelpublish_test.py | 18 ++--- test/ably/sync/rest/sync_restchannels_test.py | 10 +-- test/ably/sync/rest/sync_resthttp_test.py | 16 ++--- test/ably/sync/rest/sync_restinit_test.py | 72 +++++++++---------- .../rest/sync_restpaginatedresult_test.py | 6 +- test/ably/sync/rest/sync_restpresence_test.py | 6 +- test/ably/sync/rest/sync_restpush_test.py | 8 +-- test/ably/sync/rest/sync_restrequest_test.py | 20 +++--- test/ably/sync/rest/sync_reststats_test.py | 4 +- test/ably/sync/rest/sync_resttoken_test.py | 24 +++---- test/ably/sync/testapp.py | 6 +- test/ably/sync/utils.py | 6 +- 25 files changed, 209 insertions(+), 209 deletions(-) diff --git a/ably/sync/__init__.py b/ably/sync/__init__.py index 296dbf0d..210c52f5 100644 --- a/ably/sync/__init__.py +++ b/ably/sync/__init__.py @@ -1,7 +1,7 @@ -from ably.sync.rest.rest import AblyRest +from ably.sync.rest.rest import AblyRestSync from ably.sync.realtime.realtime import AblyRealtime -from ably.sync.rest.auth import Auth -from ably.sync.rest.push import Push +from ably.sync.rest.auth import AuthSync +from ably.sync.rest.push import PushSync from ably.sync.types.capability import Capability from ably.sync.types.channelsubscription import PushChannelSubscription from ably.sync.types.device import DeviceDetails diff --git a/ably/sync/http/http.py b/ably/sync/http/http.py index 3fcba89b..51d0bb88 100644 --- a/ably/sync/http/http.py +++ b/ably/sync/http/http.py @@ -7,7 +7,7 @@ import httpx import msgpack -from ably.sync.rest.auth import Auth +from ably.sync.rest.auth import AuthSync from ably.sync.http.httputils import HttpUtils from ably.sync.transport.defaults import Defaults from ably.sync.util.exceptions import AblyException @@ -114,7 +114,7 @@ def __getattr__(self, attr): return getattr(self.__response, attr) -class Http: +class HttpSync: CONNECTION_RETRY_DEFAULTS = { 'http_open_timeout': 4, 'http_request_timeout': 10, @@ -171,7 +171,7 @@ def make_request(self, method, path, version=None, headers=None, body=None, params = HttpUtils.get_query_params(self.options) if not skip_auth: - if self.auth.auth_mechanism == Auth.Method.BASIC and self.preferred_scheme.lower() == 'http': + if self.auth.auth_mechanism == AuthSync.Method.BASIC and self.preferred_scheme.lower() == 'http': raise AblyException( "Cannot use Basic Auth over non-TLS connections", 401, diff --git a/ably/sync/http/paginatedresult.py b/ably/sync/http/paginatedresult.py index 4f47075a..663baad9 100644 --- a/ably/sync/http/paginatedresult.py +++ b/ably/sync/http/paginatedresult.py @@ -41,7 +41,7 @@ def format_params(params=None, direction=None, start=None, end=None, limit=None, return '?' + urlencode(params) if params else '' -class PaginatedResult: +class PaginatedResultSync: def __init__(self, http, items, content_type, rel_first, rel_next, response_processor, response): self.__http = http @@ -111,7 +111,7 @@ def paginated_query_with_request(cls, http, request, response_processor, next_rel_request, response_processor, response) -class HttpPaginatedResponse(PaginatedResult): +class HttpPaginatedResponseSync(PaginatedResultSync): @property def status_code(self): return self.response.status_code diff --git a/ably/sync/realtime/realtime.py b/ably/sync/realtime/realtime.py index 51028a08..517d9676 100644 --- a/ably/sync/realtime/realtime.py +++ b/ably/sync/realtime/realtime.py @@ -1,15 +1,15 @@ import logging import asyncio from typing import Optional -from ably.sync.realtime.realtime_channel import Channels +from ably.sync.realtime.realtime_channel import ChannelsSync from ably.sync.realtime.connection import Connection, ConnectionState -from ably.sync.rest.rest import AblyRest +from ably.sync.rest.rest import AblyRestSync log = logging.getLogger(__name__) -class AblyRealtime(AblyRest): +class AblyRealtime(AblyRestSync): """ Ably Realtime Client @@ -98,7 +98,7 @@ def __init__(self, key: Optional[str] = None, loop: Optional[asyncio.AbstractEve self.key = key self.__connection = Connection(self) - self.__channels = Channels(self) + self.__channels = ChannelsSync(self) # RTN3 if self.options.auto_connect: @@ -135,6 +135,6 @@ def connection(self) -> Connection: # RTC3, RTS1 @property - def channels(self) -> Channels: + def channels(self) -> ChannelsSync: """Returns the realtime channel object""" return self.__channels diff --git a/ably/sync/realtime/realtime_channel.py b/ably/sync/realtime/realtime_channel.py index 5ed99393..805244df 100644 --- a/ably/sync/realtime/realtime_channel.py +++ b/ably/sync/realtime/realtime_channel.py @@ -4,7 +4,7 @@ from typing import Optional, TYPE_CHECKING from ably.sync.realtime.connection import ConnectionState from ably.sync.transport.websockettransport import ProtocolMessageAction -from ably.sync.rest.channel import Channel, Channels as RestChannels +from ably.sync.rest.channel import ChannelSync, ChannelsSync as RestChannels from ably.sync.types.channelstate import ChannelState, ChannelStateChange from ably.sync.types.flags import Flag, has_flag from ably.sync.types.message import Message @@ -18,7 +18,7 @@ log = logging.getLogger(__name__) -class RealtimeChannel(EventEmitter, Channel): +class RealtimeChannel(EventEmitter, ChannelSync): """ Ably Realtime Channel @@ -59,7 +59,7 @@ def __init__(self, realtime: AblyRealtime, name: str): # will be disrupted if the user called .off() to remove all listeners self.__internal_state_emitter = EventEmitter() - Channel.__init__(self, realtime, name, {}) + ChannelSync.__init__(self, realtime, name, {}) # RTL4 def attach(self) -> None: @@ -454,7 +454,7 @@ def error_reason(self) -> Optional[AblyException]: return self.__error_reason -class Channels(RestChannels): +class ChannelsSync(RestChannels): """Creates and destroys RealtimeChannel objects. Methods diff --git a/ably/sync/rest/auth.py b/ably/sync/rest/auth.py index e310b550..851a2ace 100644 --- a/ably/sync/rest/auth.py +++ b/ably/sync/rest/auth.py @@ -9,7 +9,7 @@ from ably.sync.types.options import Options if TYPE_CHECKING: - from ably.sync.rest.rest import AblyRest + from ably.sync.rest.rest import AblyRestSync from ably.sync.realtime.realtime import AblyRealtime from ably.sync.types.capability import Capability @@ -17,18 +17,18 @@ from ably.sync.types.tokenrequest import TokenRequest from ably.sync.util.exceptions import AblyAuthException, AblyException, IncompatibleClientIdException -__all__ = ["Auth"] +__all__ = ["AuthSync"] log = logging.getLogger(__name__) -class Auth: +class AuthSync: class Method: BASIC = "BASIC" TOKEN = "TOKEN" - def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options): + def __init__(self, ably: Union[AblyRestSync, AblyRealtime], options: Options): self.__ably = ably self.__auth_options = options @@ -52,7 +52,7 @@ def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options): # We have the key, no need to authenticate the client # default to using basic auth log.debug("anonymous, using basic auth") - self.__auth_mechanism = Auth.Method.BASIC + self.__auth_mechanism = AuthSync.Method.BASIC basic_key = "%s:%s" % (options.key_name, options.key_secret) basic_key = base64.b64encode(basic_key.encode('utf-8')) self.__basic_credentials = basic_key.decode('ascii') @@ -61,7 +61,7 @@ def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options): raise ValueError('If use_token_auth is False you must provide a key') # Using token auth - self.__auth_mechanism = Auth.Method.TOKEN + self.__auth_mechanism = AuthSync.Method.TOKEN if options.token_details: self.__token_details = options.token_details @@ -88,11 +88,11 @@ def get_auth_transport_param(self): auth_credentials = {} if self.auth_options.client_id: auth_credentials["client_id"] = self.auth_options.client_id - if self.__auth_mechanism == Auth.Method.BASIC: + if self.__auth_mechanism == AuthSync.Method.BASIC: key_name = self.__auth_options.key_name key_secret = self.__auth_options.key_secret auth_credentials["key"] = f"{key_name}:{key_secret}" - elif self.__auth_mechanism == Auth.Method.TOKEN: + elif self.__auth_mechanism == AuthSync.Method.TOKEN: token_details = self._ensure_valid_auth_credentials() auth_credentials["accessToken"] = token_details.token return auth_credentials @@ -106,7 +106,7 @@ def __authorize_when_necessary(self, token_params=None, auth_options=None, force return token_details def _ensure_valid_auth_credentials(self, token_params=None, auth_options=None, force=False): - self.__auth_mechanism = Auth.Method.TOKEN + self.__auth_mechanism = AuthSync.Method.TOKEN if token_params is None: token_params = dict(self.auth_options.default_token_params) else: @@ -363,7 +363,7 @@ def can_assume_client_id(self, assumed_client_id): return original_client_id == assumed_client_id def _get_auth_headers(self): - if self.__auth_mechanism == Auth.Method.BASIC: + if self.__auth_mechanism == AuthSync.Method.BASIC: # RSA7e2 if self.client_id: return { diff --git a/ably/sync/rest/channel.py b/ably/sync/rest/channel.py index f1f3f199..8804d46e 100644 --- a/ably/sync/rest/channel.py +++ b/ably/sync/rest/channel.py @@ -9,7 +9,7 @@ from methoddispatch import SingleDispatch, singledispatch import msgpack -from ably.sync.http.paginatedresult import PaginatedResult, format_params +from ably.sync.http.paginatedresult import PaginatedResultSync, format_params from ably.sync.types.channeldetails import ChannelDetails from ably.sync.types.message import Message, make_message_response_handler from ably.sync.types.presence import Presence @@ -19,7 +19,7 @@ log = logging.getLogger(__name__) -class Channel(SingleDispatch): +class ChannelSync(SingleDispatch): def __init__(self, ably, name, options): self.__ably = ably self.__name = name @@ -35,7 +35,7 @@ def history(self, direction=None, limit: int = None, start=None, end=None): path = self.__base_path + 'messages' + params message_handler = make_message_response_handler(self.__cipher) - return PaginatedResult.paginated_query( + return PaginatedResultSync.paginated_query( self.ably.http, url=path, response_processor=message_handler) def __publish_request_body(self, messages): @@ -174,7 +174,7 @@ def options(self, options): self.__cipher = cipher -class Channels: +class ChannelsSync: def __init__(self, rest): self.__ably = rest self.__all: dict = OrderedDict() @@ -184,7 +184,7 @@ def get(self, name, **kwargs): name = name.decode('ascii') if name not in self.__all: - result = self.__all[name] = Channel(self.__ably, name, kwargs) + result = self.__all[name] = ChannelSync(self.__ably, name, kwargs) else: result = self.__all[name] if len(kwargs) != 0: @@ -199,7 +199,7 @@ def __getattr__(self, name): return self.get(name) def __contains__(self, item): - if isinstance(item, Channel): + if isinstance(item, ChannelSync): name = item.name elif isinstance(item, bytes): name = item.decode('ascii') diff --git a/ably/sync/rest/push.py b/ably/sync/rest/push.py index 6133f85f..34a7ddff 100644 --- a/ably/sync/rest/push.py +++ b/ably/sync/rest/push.py @@ -1,22 +1,22 @@ from typing import Optional -from ably.sync.http.paginatedresult import PaginatedResult, format_params +from ably.sync.http.paginatedresult import PaginatedResultSync, format_params from ably.sync.types.device import DeviceDetails, device_details_response_processor from ably.sync.types.channelsubscription import PushChannelSubscription, channel_subscriptions_response_processor from ably.sync.types.channelsubscription import channels_response_processor -class Push: +class PushSync: def __init__(self, ably): self.__ably = ably - self.__admin = PushAdmin(ably) + self.__admin = PushAdminSync(ably) @property def admin(self): return self.__admin -class PushAdmin: +class PushAdminSync: def __init__(self, ably): self.__ably = ably @@ -88,7 +88,7 @@ def list(self, **params): - `**params`: the parameters used to filter the list """ path = '/push/deviceRegistrations' + format_params(params) - return PaginatedResult.paginated_query( + return PaginatedResultSync.paginated_query( self.ably.http, url=path, response_processor=device_details_response_processor) @@ -141,7 +141,7 @@ def list(self, **params): - `**params`: the parameters used to filter the list """ path = '/push/channelSubscriptions' + format_params(params) - return PaginatedResult.paginated_query(self.ably.http, url=path, + return PaginatedResultSync.paginated_query(self.ably.http, url=path, response_processor=channel_subscriptions_response_processor) def list_channels(self, **params): @@ -152,7 +152,7 @@ def list_channels(self, **params): - `**params`: the parameters used to filter the list """ path = '/push/channels' + format_params(params) - return PaginatedResult.paginated_query(self.ably.http, url=path, + return PaginatedResultSync.paginated_query(self.ably.http, url=path, response_processor=channels_response_processor) def save(self, subscription: dict): diff --git a/ably/sync/rest/rest.py b/ably/sync/rest/rest.py index 56cc3723..5f0392e1 100644 --- a/ably/sync/rest/rest.py +++ b/ably/sync/rest/rest.py @@ -2,12 +2,12 @@ from typing import Optional from urllib.parse import urlencode -from ably.sync.http.http import Http -from ably.sync.http.paginatedresult import PaginatedResult, HttpPaginatedResponse +from ably.sync.http.http import HttpSync +from ably.sync.http.paginatedresult import PaginatedResultSync, HttpPaginatedResponseSync from ably.sync.http.paginatedresult import format_params -from ably.sync.rest.auth import Auth -from ably.sync.rest.channel import Channels -from ably.sync.rest.push import Push +from ably.sync.rest.auth import AuthSync +from ably.sync.rest.channel import ChannelsSync +from ably.sync.rest.push import PushSync from ably.sync.util.exceptions import AblyException, catch_all from ably.sync.types.options import Options from ably.sync.types.stats import stats_response_processor @@ -16,7 +16,7 @@ log = logging.getLogger(__name__) -class AblyRest: +class AblyRestSync: """Ably Rest Client""" def __init__(self, key: Optional[str] = None, token: Optional[str] = None, @@ -67,13 +67,13 @@ def __init__(self, key: Optional[str] = None, token: Optional[str] = None, except AttributeError: self._is_realtime = False - self.__http = Http(self, options) - self.__auth = Auth(self, options) + self.__http = HttpSync(self, options) + self.__auth = AuthSync(self, options) self.__http.auth = self.__auth - self.__channels = Channels(self) + self.__channels = ChannelsSync(self) self.__options = options - self.__push = Push(self) + self.__push = PushSync(self) def __enter__(self): return self @@ -84,7 +84,7 @@ def stats(self, direction: Optional[str] = None, start=None, end=None, params: O """Returns the stats for this application""" formatted_params = format_params(params, direction=direction, start=start, end=end, limit=limit, unit=unit) url = '/stats' + formatted_params - return PaginatedResult.paginated_query( + return PaginatedResultSync.paginated_query( self.http, url=url, response_processor=stats_response_processor) @catch_all @@ -136,7 +136,7 @@ def response_processor(response): items = [items] return items - return HttpPaginatedResponse.paginated_query( + return HttpPaginatedResponseSync.paginated_query( self.http, method, url, version=version, body=body, headers=headers, response_processor=response_processor, raise_on_error=False) diff --git a/ably/sync/types/presence.py b/ably/sync/types/presence.py index 112c619c..35a6b498 100644 --- a/ably/sync/types/presence.py +++ b/ably/sync/types/presence.py @@ -1,7 +1,7 @@ from datetime import datetime, timedelta from urllib import parse -from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.http.paginatedresult import PaginatedResultSync from ably.sync.types.mixins import EncodeDataMixin @@ -135,7 +135,7 @@ def get(self, limit=None): path = self._path_with_qs(self.__base_path + 'presence', qs) presence_handler = make_presence_response_handler(self.__cipher) - return PaginatedResult.paginated_query( + return PaginatedResultSync.paginated_query( self.__http, url=path, response_processor=presence_handler) def history(self, limit=None, direction=None, start=None, end=None): @@ -163,7 +163,7 @@ def history(self, limit=None, direction=None, start=None, end=None): path = self._path_with_qs(self.__base_path + 'presence/history', qs) presence_handler = make_presence_response_handler(self.__cipher) - return PaginatedResult.paginated_query( + return PaginatedResultSync.paginated_query( self.__http, url=path, response_processor=presence_handler) diff --git a/test/ably/sync/rest/sync_encoders_test.py b/test/ably/sync/rest/sync_encoders_test.py index 8fde66b4..d70b22d3 100644 --- a/test/ably/sync/rest/sync_encoders_test.py +++ b/test/ably/sync/rest/sync_encoders_test.py @@ -31,7 +31,7 @@ def tearDown(self): def test_text_utf8(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', 'foó') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['data'] == 'foó' @@ -41,7 +41,7 @@ def test_str(self): # This test only makes sense for py2 channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', 'foo') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['data'] == 'foo' @@ -50,7 +50,7 @@ def test_str(self): def test_with_binary_type(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args raw_data = json.loads(kwargs['body'])['data'] @@ -60,7 +60,7 @@ def test_with_binary_type(self): def test_with_bytes_type(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', b'foo') _, kwargs = post_mock.call_args raw_data = json.loads(kwargs['body'])['data'] @@ -70,7 +70,7 @@ def test_with_bytes_type(self): def test_with_json_dict_data(self): channel = self.ably.channels["persisted:publish"] data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args raw_data = json.loads(json.loads(kwargs['body'])['data']) @@ -80,7 +80,7 @@ def test_with_json_dict_data(self): def test_with_json_list_data(self): channel = self.ably.channels["persisted:publish"] data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args raw_data = json.loads(json.loads(kwargs['body'])['data']) @@ -162,7 +162,7 @@ def decrypt(self, payload, options=None): def test_text_utf8(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', 'fóo') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['encoding'].strip('/') == 'utf-8/cipher+aes-128-cbc/base64' @@ -173,7 +173,7 @@ def test_str(self): # This test only makes sense for py2 channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', 'foo') _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['data'] == 'foo' @@ -183,7 +183,7 @@ def test_with_binary_type(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args @@ -196,7 +196,7 @@ def test_with_json_dict_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' @@ -207,7 +207,7 @@ def test_with_json_list_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.Http.post', new_callable=Mock) as post_mock: + with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' @@ -270,7 +270,7 @@ def decode(self, data): def test_text_utf8(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', 'foó') _, kwargs = post_mock.call_args @@ -280,7 +280,7 @@ def test_text_utf8(self): def test_with_binary_type(self): channel = self.ably.channels["persisted:publish"] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args @@ -290,7 +290,7 @@ def test_with_binary_type(self): def test_with_json_dict_data(self): channel = self.ably.channels["persisted:publish"] data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args @@ -301,7 +301,7 @@ def test_with_json_dict_data(self): def test_with_json_list_data(self): channel = self.ably.channels["persisted:publish"] data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args @@ -368,7 +368,7 @@ def decode(self, data): def test_text_utf8(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', 'fóo') _, kwargs = post_mock.call_args @@ -380,7 +380,7 @@ def test_with_binary_type(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', bytearray(b'foo')) _, kwargs = post_mock.call_args @@ -394,7 +394,7 @@ def test_with_json_dict_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args @@ -406,7 +406,7 @@ def test_with_json_list_data(self): channel = self.ably.channels.get("persisted:publish_enc", cipher=self.cipher_params) data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish('event', data) _, kwargs = post_mock.call_args diff --git a/test/ably/sync/rest/sync_restauth_test.py b/test/ably/sync/rest/sync_restauth_test.py index 660f1ae6..1a2db77d 100644 --- a/test/ably/sync/rest/sync_restauth_test.py +++ b/test/ably/sync/rest/sync_restauth_test.py @@ -11,8 +11,8 @@ from httpx import Response, Client import ably -from ably.sync import AblyRest -from ably.sync import Auth +from ably.sync import AblyRestSync +from ably.sync import AuthSync from ably.sync import AblyAuthException from ably.sync.types.tokendetails import TokenDetails @@ -33,21 +33,21 @@ def setUp(self): self.test_vars = TestApp.get_test_vars() def test_auth_init_key_only(self): - ably = AblyRest(key=self.test_vars["keys"][0]["key_str"]) - assert Auth.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"]) + assert AuthSync.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" assert ably.auth.auth_options.key_name == self.test_vars["keys"][0]['key_name'] assert ably.auth.auth_options.key_secret == self.test_vars["keys"][0]['key_secret'] def test_auth_init_token_only(self): - ably = AblyRest(token="this_is_not_really_a_token") + ably = AblyRestSync(token="this_is_not_really_a_token") - assert Auth.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" def test_auth_token_details(self): td = TokenDetails() - ably = AblyRest(token_details=td) + ably = AblyRestSync(token_details=td) - assert Auth.Method.TOKEN == ably.auth.auth_mechanism + assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism assert ably.auth.token_details is td def test_auth_init_with_token_callback(self): @@ -68,21 +68,21 @@ def token_callback(token_params): pass assert callback_called, "Token callback not called" - assert Auth.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" def test_auth_init_with_key_and_client_id(self): - ably = AblyRest(key=self.test_vars["keys"][0]["key_str"], client_id='testClientId') + ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], client_id='testClientId') - assert Auth.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert AuthSync.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" assert ably.auth.client_id == 'testClientId' def test_auth_init_with_token(self): ably = TestApp.get_ably_rest(key=None, token="this_is_not_really_a_token") - assert Auth.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" # RSA11 def test_request_basic_auth_header(self): - ably = AblyRest(key_secret='foo', key_name='bar') + ably = AblyRestSync(key_secret='foo', key_name='bar') with mock.patch.object(Client, 'send') as get_mock: try: @@ -95,7 +95,7 @@ def test_request_basic_auth_header(self): # RSA7e2 def test_request_basic_auth_header_with_client_id(self): - ably = AblyRest(key_secret='foo', key_name='bar', client_id='client_id') + ably = AblyRestSync(key_secret='foo', key_name='bar', client_id='client_id') with mock.patch.object(Client, 'send') as get_mock: try: @@ -107,7 +107,7 @@ def test_request_basic_auth_header_with_client_id(self): assert client_id == base64.b64encode('client_id'.encode('ascii')).decode('utf-8') def test_request_token_auth_header(self): - ably = AblyRest(token='not_a_real_token') + ably = AblyRestSync(token='not_a_real_token') with mock.patch.object(Client, 'send') as get_mock: try: @@ -120,46 +120,46 @@ def test_request_token_auth_header(self): def test_if_cant_authenticate_via_token(self): with pytest.raises(ValueError): - AblyRest(use_token_auth=True) + AblyRestSync(use_token_auth=True) def test_use_auth_token(self): - ably = AblyRest(use_token_auth=True, key=self.test_vars["keys"][0]["key_str"]) - assert ably.auth.auth_mechanism == Auth.Method.TOKEN + ably = AblyRestSync(use_token_auth=True, key=self.test_vars["keys"][0]["key_str"]) + assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN def test_with_client_id(self): - ably = AblyRest(use_token_auth=True, client_id='client_id', key=self.test_vars["keys"][0]["key_str"]) - assert ably.auth.auth_mechanism == Auth.Method.TOKEN + ably = AblyRestSync(use_token_auth=True, client_id='client_id', key=self.test_vars["keys"][0]["key_str"]) + assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN def test_with_auth_url(self): - ably = AblyRest(auth_url='auth_url') - assert ably.auth.auth_mechanism == Auth.Method.TOKEN + ably = AblyRestSync(auth_url='auth_url') + assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN def test_with_auth_callback(self): - ably = AblyRest(auth_callback=lambda x: x) - assert ably.auth.auth_mechanism == Auth.Method.TOKEN + ably = AblyRestSync(auth_callback=lambda x: x) + assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN def test_with_token(self): - ably = AblyRest(token='a token') - assert ably.auth.auth_mechanism == Auth.Method.TOKEN + ably = AblyRestSync(token='a token') + assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN def test_default_ttl_is_1hour(self): one_hour_in_ms = 60 * 60 * 1000 assert TokenDetails.DEFAULTS['ttl'] == one_hour_in_ms def test_with_auth_method(self): - ably = AblyRest(token='a token', auth_method='POST') + ably = AblyRestSync(token='a token', auth_method='POST') assert ably.auth.auth_options.auth_method == 'POST' def test_with_auth_headers(self): - ably = AblyRest(token='a token', auth_headers={'h1': 'v1'}) + ably = AblyRestSync(token='a token', auth_headers={'h1': 'v1'}) assert ably.auth.auth_options.auth_headers == {'h1': 'v1'} def test_with_auth_params(self): - ably = AblyRest(token='a token', auth_params={'p': 'v'}) + ably = AblyRestSync(token='a token', auth_params={'p': 'v'}) assert ably.auth.auth_options.auth_params == {'p': 'v'} def test_with_default_token_params(self): - ably = AblyRest(key=self.test_vars["keys"][0]["key_str"], + ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], default_token_params={'ttl': 12345}) assert ably.auth.auth_options.default_token_params == {'ttl': 12345} @@ -178,11 +178,11 @@ def per_protocol_setup(self, use_binary_protocol): self.use_binary_protocol = use_binary_protocol def test_if_authorize_changes_auth_mechanism_to_token(self): - assert Auth.Method.BASIC == self.ably.auth.auth_mechanism, "Unexpected Auth method mismatch" + assert AuthSync.Method.BASIC == self.ably.auth.auth_mechanism, "Unexpected Auth method mismatch" self.ably.auth.authorize() - assert Auth.Method.TOKEN == self.ably.auth.auth_mechanism, "Authorize should change the Auth method" + assert AuthSync.Method.TOKEN == self.ably.auth.auth_mechanism, "Authorize should change the Auth method" # RSA10a @dont_vary_protocol @@ -210,7 +210,7 @@ def test_authorize_returns_a_token_details(self): def test_authorize_adheres_to_request_token(self): token_params = {'ttl': 10, 'client_id': 'client_id'} auth_params = {'auth_url': 'somewhere.com', 'query_time': True} - with mock.patch('ably.sync.rest.auth.Auth.request_token', new_callable=Mock) as request_mock: + with mock.patch('ably.sync.rest.auth.AuthSync.request_token', new_callable=Mock) as request_mock: self.ably.auth.authorize(token_params, auth_params) token_called, auth_called = request_mock.call_args @@ -249,7 +249,7 @@ def test_if_parameters_are_stored_and_used_as_defaults(self): auth_options = dict(self.ably.auth.auth_options.auth_options) auth_options['auth_headers'] = {'a_headers': 'a_value'} self.ably.auth.authorize({'ttl': 555}, auth_options) - with mock.patch('ably.sync.rest.auth.Auth.request_token', + with mock.patch('ably.sync.rest.auth.AuthSync.request_token', wraps=self.ably.auth.request_token) as request_mock: self.ably.auth.authorize() @@ -261,7 +261,7 @@ def test_if_parameters_are_stored_and_used_as_defaults(self): auth_options = dict(self.ably.auth.auth_options.auth_options) auth_options['auth_headers'] = None self.ably.auth.authorize({}, auth_options) - with mock.patch('ably.sync.rest.auth.Auth.request_token', + with mock.patch('ably.sync.rest.auth.AuthSync.request_token', wraps=self.ably.auth.request_token) as request_mock: self.ably.auth.authorize() diff --git a/test/ably/sync/rest/sync_restchannelhistory_test.py b/test/ably/sync/rest/sync_restchannelhistory_test.py index 14b86ac5..2263aeaa 100644 --- a/test/ably/sync/rest/sync_restchannelhistory_test.py +++ b/test/ably/sync/rest/sync_restchannelhistory_test.py @@ -3,7 +3,7 @@ import respx from ably.sync import AblyException -from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.http.paginatedresult import PaginatedResultSync from test.ably.sync.testapp import TestApp from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase @@ -32,7 +32,7 @@ def test_channel_history_types(self): history0.publish('history3', ['This is a JSONArray message payload']) history = history0.history() - assert isinstance(history, PaginatedResult) + assert isinstance(history, PaginatedResultSync) messages = history.items assert messages is not None, "Expected non-None messages" assert 4 == len(messages), "Expected 4 messages" diff --git a/test/ably/sync/rest/sync_restchannelpublish_test.py b/test/ably/sync/rest/sync_restchannelpublish_test.py index 582dc94b..a44ab265 100644 --- a/test/ably/sync/rest/sync_restchannelpublish_test.py +++ b/test/ably/sync/rest/sync_restchannelpublish_test.py @@ -12,7 +12,7 @@ from ably.sync import api_version from ably.sync import AblyException, IncompatibleClientIdException -from ably.sync.rest.auth import Auth +from ably.sync.rest.auth import AuthSync from ably.sync.types.message import Message from ably.sync.types.tokendetails import TokenDetails from ably.sync.util import case @@ -105,7 +105,7 @@ def test_message_list_generate_one_request(self): expected_messages = [Message("name-{}".format(i), str(i)) for i in range(3)] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish(messages=expected_messages) assert post_mock.call_count == 1 @@ -186,7 +186,7 @@ def test_publish_message_null_name_and_data_keys_arent_sent(self): channel = self.ably.channels[ self.get_channel_name('persisted:null_name_and_data_keys_arent_sent_channel')] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish(name=None, data=None) @@ -238,7 +238,7 @@ def test_token_is_bound_to_options_client_id_after_publish(self): # defined after publish assert isinstance(self.ably_with_client_id.auth.token_details, TokenDetails) assert self.ably_with_client_id.auth.token_details.client_id == self.client_id - assert self.ably_with_client_id.auth.auth_mechanism == Auth.Method.TOKEN + assert self.ably_with_client_id.auth.auth_mechanism == AuthSync.Method.TOKEN history = channel.history() assert history.items[0].client_id == self.client_id @@ -246,7 +246,7 @@ def test_publish_message_without_client_id_on_identified_client(self): channel = self.ably_with_client_id.channels[ self.get_channel_name('persisted:no_client_id_identified_client')] - with mock.patch('ably.sync.rest.rest.Http.post', + with mock.patch('ably.sync.rest.rest.HttpSync.post', wraps=channel.ably.http.post) as post_mock: channel.publish(name='publish', data='test') @@ -486,7 +486,7 @@ def test_message_serialization(self): 'id': 'foobar', } message = Message(**data) - request_body = channel._Channel__publish_request_body(messages=[message]) + request_body = channel._ChannelSync__publish_request_body(messages=[message]) input_keys = set(case.snake_to_camel(x) for x in data.keys()) assert input_keys - set(request_body) == set() @@ -496,7 +496,7 @@ def test_idempotent_library_generated(self): channel = self.ably_idempotent.channels[self.get_channel_name()] message = Message('name', 'data') - request_body = channel._Channel__publish_request_body(messages=[message]) + request_body = channel._ChannelSync__publish_request_body(messages=[message]) base_id, serial = request_body['id'].split(':') assert len(base64.b64decode(base_id)) >= 9 assert serial == '0' @@ -507,7 +507,7 @@ def test_idempotent_client_supplied(self): channel = self.ably_idempotent.channels[self.get_channel_name()] message = Message('name', 'data', id='foobar') - request_body = channel._Channel__publish_request_body(messages=[message]) + request_body = channel._ChannelSync__publish_request_body(messages=[message]) assert request_body['id'] == 'foobar' # RSL1k3 @@ -519,7 +519,7 @@ def test_idempotent_mixed_ids(self): Message('name', 'data', id='foobar'), Message('name', 'data'), ] - request_body = channel._Channel__publish_request_body(messages=messages) + request_body = channel._ChannelSync__publish_request_body(messages=messages) assert request_body[0]['id'] == 'foobar' assert 'id' not in request_body[1] diff --git a/test/ably/sync/rest/sync_restchannels_test.py b/test/ably/sync/rest/sync_restchannels_test.py index 43401d36..88587313 100644 --- a/test/ably/sync/rest/sync_restchannels_test.py +++ b/test/ably/sync/rest/sync_restchannels_test.py @@ -3,7 +3,7 @@ import pytest from ably.sync import AblyException -from ably.sync.rest.channel import Channel, Channels, Presence +from ably.sync.rest.channel import ChannelSync, ChannelsSync, Presence from ably.sync.util.crypto import generate_random_key from test.ably.sync.testapp import TestApp @@ -22,18 +22,18 @@ def tearDown(self): def test_rest_channels_attr(self): assert hasattr(self.ably, 'channels') - assert isinstance(self.ably.channels, Channels) + assert isinstance(self.ably.channels, ChannelsSync) def test_channels_get_returns_new_or_existing(self): channel = self.ably.channels.get('new_channel') - assert isinstance(channel, Channel) + assert isinstance(channel, ChannelSync) channel_same = self.ably.channels.get('new_channel') assert channel is channel_same def test_channels_get_returns_new_with_options(self): key = generate_random_key() channel = self.ably.channels.get('new_channel', cipher={'key': key}) - assert isinstance(channel, Channel) + assert isinstance(channel, ChannelSync) assert channel.cipher.secret_key is key def test_channels_get_updates_existing_with_options(self): @@ -67,7 +67,7 @@ def test_channels_iteration(self): assert isinstance(self.ably.channels, Iterable) for name, channel in zip(channel_names, self.ably.channels): - assert isinstance(channel, Channel) + assert isinstance(channel, ChannelSync) assert name == channel.name # RSN4a, RSN4b diff --git a/test/ably/sync/rest/sync_resthttp_test.py b/test/ably/sync/rest/sync_resthttp_test.py index 372916ea..0c00b55b 100644 --- a/test/ably/sync/rest/sync_resthttp_test.py +++ b/test/ably/sync/rest/sync_resthttp_test.py @@ -10,7 +10,7 @@ import respx from httpx import Response -from ably.sync import AblyRest +from ably.sync import AblyRestSync from ably.sync.transport.defaults import Defaults from ably.sync.types.options import Options from ably.sync.util.exceptions import AblyException @@ -20,7 +20,7 @@ class TestRestHttp(BaseAsyncTestCase): def test_max_retry_attempts_and_timeouts_defaults(self): - ably = AblyRest(token="foo") + ably = AblyRestSync(token="foo") assert 'http_open_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS assert 'http_request_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS @@ -33,7 +33,7 @@ def test_max_retry_attempts_and_timeouts_defaults(self): ably.close() def test_cumulative_timeout(self): - ably = AblyRest(token="foo") + ably = AblyRestSync(token="foo") assert 'http_max_retry_duration' in ably.http.CONNECTION_RETRY_DEFAULTS ably.options.http_max_retry_duration = 0.5 @@ -50,7 +50,7 @@ def sleep_and_raise(*args, **kwargs): ably.close() def test_host_fallback(self): - ably = AblyRest(token="foo") + ably = AblyRestSync(token="foo") def make_url(host): base_url = "%s://%s:%d" % (ably.http.preferred_scheme, @@ -82,7 +82,7 @@ def make_url(host): @respx.mock def test_no_host_fallback_nor_retries_if_custom_host(self): custom_host = 'example.org' - ably = AblyRest(token="foo", rest_host=custom_host) + ably = AblyRestSync(token="foo", rest_host=custom_host) mock_route = respx.get("https://example.org").mock(side_effect=httpx.RequestError('')) @@ -132,7 +132,7 @@ def side_effect(*args, **kwargs): @respx.mock def test_no_retry_if_not_500_to_599_http_code(self): default_host = Options().get_rest_host() - ably = AblyRest(token="foo") + ably = AblyRestSync(token="foo") default_url = "%s://%s:%d/" % ( ably.http.preferred_scheme, @@ -157,7 +157,7 @@ def test_500_errors(self): https://github.com/ably/ably-python/issues/160 """ - ably = AblyRest(token="foo") + ably = AblyRestSync(token="foo") def raise_ably_exception(*args, **kwargs): raise AblyException(message="", status_code=500, code=50000) @@ -172,7 +172,7 @@ def raise_ably_exception(*args, **kwargs): ably.close() def test_custom_http_timeouts(self): - ably = AblyRest( + ably = AblyRestSync( token="foo", http_request_timeout=30, http_open_timeout=8, http_max_retry_count=6, http_max_retry_duration=20) diff --git a/test/ably/sync/rest/sync_restinit_test.py b/test/ably/sync/rest/sync_restinit_test.py index 3b50b4b0..327076b9 100644 --- a/test/ably/sync/rest/sync_restinit_test.py +++ b/test/ably/sync/rest/sync_restinit_test.py @@ -2,7 +2,7 @@ import pytest from httpx import Client -from ably.sync import AblyRest +from ably.sync import AblyRestSync from ably.sync import AblyException from ably.sync.transport.defaults import Defaults from ably.sync.types.tokendetails import TokenDetails @@ -18,7 +18,7 @@ def setUp(self): @dont_vary_protocol def test_key_only(self): - ably = AblyRest(key=self.test_vars["keys"][0]["key_str"]) + ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"]) assert ably.options.key_name == self.test_vars["keys"][0]["key_name"], "Key name does not match" assert ably.options.key_secret == self.test_vars["keys"][0]["key_secret"], "Key secret does not match" @@ -27,65 +27,65 @@ def per_protocol_setup(self, use_binary_protocol): @dont_vary_protocol def test_with_token(self): - ably = AblyRest(token="foo") + ably = AblyRestSync(token="foo") assert ably.options.auth_token == "foo", "Token not set at options" @dont_vary_protocol def test_with_token_details(self): td = TokenDetails() - ably = AblyRest(token_details=td) + ably = AblyRestSync(token_details=td) assert ably.options.token_details is td @dont_vary_protocol def test_with_options_token_callback(self): def token_callback(**params): return "this_is_not_really_a_token_request" - AblyRest(auth_callback=token_callback) + AblyRestSync(auth_callback=token_callback) @dont_vary_protocol def test_ambiguous_key_raises_value_error(self): with pytest.raises(ValueError, match="mutually exclusive"): - AblyRest(key=self.test_vars["keys"][0]["key_str"], key_name='x') + AblyRestSync(key=self.test_vars["keys"][0]["key_str"], key_name='x') with pytest.raises(ValueError, match="mutually exclusive"): - AblyRest(key=self.test_vars["keys"][0]["key_str"], key_secret='x') + AblyRestSync(key=self.test_vars["keys"][0]["key_str"], key_secret='x') @dont_vary_protocol def test_with_key_name_or_secret_only(self): with pytest.raises(ValueError, match="key is missing"): - AblyRest(key_name='x') + AblyRestSync(key_name='x') with pytest.raises(ValueError, match="key is missing"): - AblyRest(key_secret='x') + AblyRestSync(key_secret='x') @dont_vary_protocol def test_with_key_name_and_secret(self): - ably = AblyRest(key_name="foo", key_secret="bar") + ably = AblyRestSync(key_name="foo", key_secret="bar") assert ably.options.key_name == "foo", "Key name does not match" assert ably.options.key_secret == "bar", "Key secret does not match" @dont_vary_protocol def test_with_options_auth_url(self): - AblyRest(auth_url='not_really_an_url') + AblyRestSync(auth_url='not_really_an_url') # RSC11 @dont_vary_protocol def test_rest_host_and_environment(self): # rest host - ably = AblyRest(token='foo', rest_host="some.other.host") + ably = AblyRestSync(token='foo', rest_host="some.other.host") assert "some.other.host" == ably.options.rest_host, "Unexpected host mismatch" # environment: production - ably = AblyRest(token='foo', environment="production") + ably = AblyRestSync(token='foo', environment="production") host = ably.options.get_rest_host() assert "rest.ably.io" == host, "Unexpected host mismatch %s" % host # environment: other - ably = AblyRest(token='foo', environment="sandbox") + ably = AblyRestSync(token='foo', environment="sandbox") host = ably.options.get_rest_host() assert "sandbox-rest.ably.io" == host, "Unexpected host mismatch %s" % host # both, as per #TO3k2 with pytest.raises(ValueError): - ably = AblyRest(token='foo', rest_host="some.other.host", + ably = AblyRestSync(token='foo', rest_host="some.other.host", environment="some.other.environment") # RSC15 @@ -99,68 +99,68 @@ def test_fallback_hosts(self): # Fallback hosts specified (RSC15g1) for aux in fallback_hosts: - ably = AblyRest(token='foo', fallback_hosts=aux) + ably = AblyRestSync(token='foo', fallback_hosts=aux) assert sorted(aux) == sorted(ably.options.get_fallback_rest_hosts()) # Specify environment (RSC15g2) - ably = AblyRest(token='foo', environment='sandbox', http_max_retry_count=10) + ably = AblyRestSync(token='foo', environment='sandbox', http_max_retry_count=10) assert sorted(Defaults.get_environment_fallback_hosts('sandbox')) == sorted( ably.options.get_fallback_rest_hosts()) # Fallback hosts and environment not specified (RSC15g3) - ably = AblyRest(token='foo', http_max_retry_count=10) + ably = AblyRestSync(token='foo', http_max_retry_count=10) assert sorted(Defaults.fallback_hosts) == sorted(ably.options.get_fallback_rest_hosts()) # RSC15f - ably = AblyRest(token='foo') + ably = AblyRestSync(token='foo') assert 600000 == ably.options.fallback_retry_timeout - ably = AblyRest(token='foo', fallback_retry_timeout=1000) + ably = AblyRestSync(token='foo', fallback_retry_timeout=1000) assert 1000 == ably.options.fallback_retry_timeout @dont_vary_protocol def test_specified_realtime_host(self): - ably = AblyRest(token='foo', realtime_host="some.other.host") + ably = AblyRestSync(token='foo', realtime_host="some.other.host") assert "some.other.host" == ably.options.realtime_host, "Unexpected host mismatch" @dont_vary_protocol def test_specified_port(self): - ably = AblyRest(token='foo', port=9998, tls_port=9999) + ably = AblyRestSync(token='foo', port=9998, tls_port=9999) assert 9999 == Defaults.get_port(ably.options),\ "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port @dont_vary_protocol def test_specified_non_tls_port(self): - ably = AblyRest(token='foo', port=9998, tls=False) + ably = AblyRestSync(token='foo', port=9998, tls=False) assert 9998 == Defaults.get_port(ably.options),\ "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port @dont_vary_protocol def test_specified_tls_port(self): - ably = AblyRest(token='foo', tls_port=9999, tls=True) + ably = AblyRestSync(token='foo', tls_port=9999, tls=True) assert 9999 == Defaults.get_port(ably.options),\ "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port @dont_vary_protocol def test_tls_defaults_to_true(self): - ably = AblyRest(token='foo') + ably = AblyRestSync(token='foo') assert ably.options.tls, "Expected encryption to default to true" assert Defaults.tls_port == Defaults.get_port(ably.options), "Unexpected port mismatch" @dont_vary_protocol def test_tls_can_be_disabled(self): - ably = AblyRest(token='foo', tls=False) + ably = AblyRestSync(token='foo', tls=False) assert not ably.options.tls, "Expected encryption to be False" assert Defaults.port == Defaults.get_port(ably.options), "Unexpected port mismatch" @dont_vary_protocol def test_with_no_params(self): with pytest.raises(ValueError): - AblyRest() + AblyRestSync() @dont_vary_protocol def test_with_no_auth_params(self): with pytest.raises(ValueError): - AblyRest(port=111) + AblyRestSync(port=111) # RSA10k def test_query_time_param(self): @@ -168,8 +168,8 @@ def test_query_time_param(self): use_binary_protocol=self.use_binary_protocol) timestamp = ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRest.time', wraps=ably.time) as server_time,\ - patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=ably.time) as server_time,\ + patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: ably.auth.request_token() assert local_time.call_count == 1 assert server_time.call_count == 1 @@ -181,19 +181,19 @@ def test_query_time_param(self): @dont_vary_protocol def test_requests_over_https_production(self): - ably = AblyRest(token='token') + ably = AblyRestSync(token='token') assert 'https://rest.ably.io' == '{0}://{1}'.format(ably.http.preferred_scheme, ably.http.preferred_host) assert ably.http.preferred_port == 443 @dont_vary_protocol def test_requests_over_http_production(self): - ably = AblyRest(token='token', tls=False) + ably = AblyRestSync(token='token', tls=False) assert 'http://rest.ably.io' == '{0}://{1}'.format(ably.http.preferred_scheme, ably.http.preferred_host) assert ably.http.preferred_port == 80 @dont_vary_protocol def test_request_basic_auth_over_http_fails(self): - ably = AblyRest(key_secret='foo', key_name='bar', tls=False) + ably = AblyRestSync(key_secret='foo', key_name='bar', tls=False) with pytest.raises(AblyException) as excinfo: ably.http.get('/time', skip_auth=False) @@ -204,8 +204,8 @@ def test_request_basic_auth_over_http_fails(self): @dont_vary_protocol def test_environment(self): - ably = AblyRest(token='token', environment='custom') - with patch.object(Client, 'send', wraps=ably.http._Http__client.send) as get_mock: + ably = AblyRestSync(token='token', environment='custom') + with patch.object(Client, 'send', wraps=ably.http._HttpSync__client.send) as get_mock: try: ably.time() except AblyException: @@ -217,7 +217,7 @@ def test_environment(self): @dont_vary_protocol def test_accepts_custom_http_timeouts(self): - ably = AblyRest( + ably = AblyRestSync( token="foo", http_request_timeout=30, http_open_timeout=8, http_max_retry_count=6, http_max_retry_duration=20) diff --git a/test/ably/sync/rest/sync_restpaginatedresult_test.py b/test/ably/sync/rest/sync_restpaginatedresult_test.py index 348e6b47..312ce100 100644 --- a/test/ably/sync/rest/sync_restpaginatedresult_test.py +++ b/test/ably/sync/rest/sync_restpaginatedresult_test.py @@ -1,7 +1,7 @@ import respx from httpx import Response -from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.http.paginatedresult import PaginatedResultSync from test.ably.sync.testapp import TestApp from test.ably.sync.utils import BaseAsyncTestCase @@ -53,11 +53,11 @@ def setUp(self): # start intercepting requests self.mocked_api.start() - self.paginated_result = PaginatedResult.paginated_query( + self.paginated_result = PaginatedResultSync.paginated_query( self.ably.http, url='http://rest.ably.io/channels/channel_name/ch1', response_processor=lambda response: response.to_native()) - self.paginated_result_with_headers = PaginatedResult.paginated_query( + self.paginated_result_with_headers = PaginatedResultSync.paginated_query( self.ably.http, url='http://rest.ably.io/channels/channel_name/ch2', response_processor=lambda response: response.to_native()) diff --git a/test/ably/sync/rest/sync_restpresence_test.py b/test/ably/sync/rest/sync_restpresence_test.py index d3c81ab1..2789ccb0 100644 --- a/test/ably/sync/rest/sync_restpresence_test.py +++ b/test/ably/sync/rest/sync_restpresence_test.py @@ -3,7 +3,7 @@ import pytest import respx -from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.http.paginatedresult import PaginatedResultSync from ably.sync.types.presence import PresenceMessage from test.ably.sync.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseAsyncTestCase @@ -27,7 +27,7 @@ def per_protocol_setup(self, use_binary_protocol): def test_channel_presence_get(self): presence_page = self.channel.presence.get() - assert isinstance(presence_page, PaginatedResult) + assert isinstance(presence_page, PaginatedResultSync) assert len(presence_page.items) == 6 member = presence_page.items[0] assert isinstance(member, PresenceMessage) @@ -40,7 +40,7 @@ def test_channel_presence_get(self): def test_channel_presence_history(self): presence_history = self.channel.presence.history() - assert isinstance(presence_history, PaginatedResult) + assert isinstance(presence_history, PaginatedResultSync) assert len(presence_history.items) == 6 member = presence_history.items[0] assert isinstance(member, PresenceMessage) diff --git a/test/ably/sync/rest/sync_restpush_test.py b/test/ably/sync/rest/sync_restpush_test.py index c1127d2e..d8114c32 100644 --- a/test/ably/sync/rest/sync_restpush_test.py +++ b/test/ably/sync/rest/sync_restpush_test.py @@ -7,7 +7,7 @@ from ably.sync import AblyException, AblyAuthException from ably.sync import DeviceDetails, PushChannelSubscription -from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.http.paginatedresult import PaginatedResultSync from test.ably.sync.testapp import TestApp from test.ably.sync.utils import VaryByProtocolTestsMetaclass, BaseAsyncTestCase @@ -166,7 +166,7 @@ def test_admin_device_registrations_list(self): list_devices = self.ably.push.admin.device_registrations.list list_response = list_devices() - assert type(list_response) is PaginatedResult + assert type(list_response) is PaginatedResultSync assert type(list_response.items) is list assert type(list_response.items[0]) is DeviceDetails @@ -267,7 +267,7 @@ def test_admin_channel_subscriptions_list(self): list_response = list_(channel=channel) - assert type(list_response) is PaginatedResult + assert type(list_response) is PaginatedResultSync assert type(list_response.items) is list assert type(list_response.items[0]) is PushChannelSubscription @@ -297,7 +297,7 @@ def test_admin_channels_list(self): list_ = self.ably.push.admin.channel_subscriptions.list_channels list_response = list_() - assert type(list_response) is PaginatedResult + assert type(list_response) is PaginatedResultSync assert type(list_response.items) is list assert type(list_response.items[0]) is str diff --git a/test/ably/sync/rest/sync_restrequest_test.py b/test/ably/sync/rest/sync_restrequest_test.py index cad062c3..9beb3c11 100644 --- a/test/ably/sync/rest/sync_restrequest_test.py +++ b/test/ably/sync/rest/sync_restrequest_test.py @@ -2,8 +2,8 @@ import pytest import respx -from ably.sync import AblyRest -from ably.sync.http.paginatedresult import HttpPaginatedResponse +from ably.sync import AblyRestSync +from ably.sync.http.paginatedresult import HttpPaginatedResponseSync from ably.sync.transport.defaults import Defaults from test.ably.sync.testapp import TestApp from test.ably.sync.utils import BaseAsyncTestCase @@ -35,7 +35,7 @@ def test_post(self): body = {'name': 'test-post', 'data': 'lorem ipsum'} result = self.ably.request('POST', self.path, body=body, version=Defaults.protocol_version) - assert isinstance(result, HttpPaginatedResponse) # RSC19d + assert isinstance(result, HttpPaginatedResponseSync) # RSC19d # HP3 assert type(result.items) is list assert len(result.items) == 1 @@ -46,11 +46,11 @@ def test_get(self): params = {'limit': 10, 'direction': 'forwards'} result = self.ably.request('GET', self.path, params=params, version=Defaults.protocol_version) - assert isinstance(result, HttpPaginatedResponse) # RSC19d + assert isinstance(result, HttpPaginatedResponseSync) # RSC19d # HP2 - assert isinstance(result.next(), HttpPaginatedResponse) - assert isinstance(result.first(), HttpPaginatedResponse) + assert isinstance(result.next(), HttpPaginatedResponseSync) + assert isinstance(result.first(), HttpPaginatedResponseSync) # HP3 assert isinstance(result.items, list) @@ -70,7 +70,7 @@ def test_get(self): @dont_vary_protocol def test_not_found(self): result = self.ably.request('GET', '/not-found', version=Defaults.protocol_version) - assert isinstance(result, HttpPaginatedResponse) # RSC19d + assert isinstance(result, HttpPaginatedResponseSync) # RSC19d assert result.status_code == 404 # HP4 assert result.success is False # HP5 @@ -78,7 +78,7 @@ def test_not_found(self): def test_error(self): params = {'limit': 'abc'} result = self.ably.request('GET', self.path, params=params, version=Defaults.protocol_version) - assert isinstance(result, HttpPaginatedResponse) # RSC19d + assert isinstance(result, HttpPaginatedResponseSync) # RSC19d assert result.status_code == 400 # HP4 assert not result.success assert result.error_code @@ -95,7 +95,7 @@ def test_headers(self): def test_timeout(self): # Timeout timeout = 0.000001 - ably = AblyRest(token="foo", http_request_timeout=timeout) + ably = AblyRestSync(token="foo", http_request_timeout=timeout) assert ably.http.http_request_timeout == timeout with pytest.raises(httpx.ReadTimeout): ably.request('GET', '/time', version=Defaults.protocol_version) @@ -117,7 +117,7 @@ def test_timeout(self): ably.close() # Bad host, no Fallback - ably = AblyRest(key=self.test_vars["keys"][0]["key_str"], + ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], rest_host='some.other.host', port=self.test_vars["port"], tls_port=self.test_vars["tls_port"], diff --git a/test/ably/sync/rest/sync_reststats_test.py b/test/ably/sync/rest/sync_reststats_test.py index a621c927..dd2c91bc 100644 --- a/test/ably/sync/rest/sync_reststats_test.py +++ b/test/ably/sync/rest/sync_reststats_test.py @@ -6,7 +6,7 @@ from ably.sync.types.stats import Stats from ably.sync.util.exceptions import AblyException -from ably.sync.http.paginatedresult import PaginatedResult +from ably.sync.http.paginatedresult import PaginatedResultSync from test.ably.sync.testapp import TestApp from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase @@ -179,7 +179,7 @@ def test_protocols(self): def test_paginated_response(self): stats_pages = self.ably.stats(**self.get_params()) - assert isinstance(stats_pages, PaginatedResult) + assert isinstance(stats_pages, PaginatedResultSync) assert isinstance(stats_pages.items[0], Stats) def test_units(self): diff --git a/test/ably/sync/rest/sync_resttoken_test.py b/test/ably/sync/rest/sync_resttoken_test.py index 03e1c480..ee3a1562 100644 --- a/test/ably/sync/rest/sync_resttoken_test.py +++ b/test/ably/sync/rest/sync_resttoken_test.py @@ -6,7 +6,7 @@ import pytest from ably.sync import AblyException -from ably.sync import AblyRest +from ably.sync import AblyRestSync from ably.sync import Capability from ably.sync.types.tokendetails import TokenDetails from ably.sync.types.tokenrequest import TokenRequest @@ -40,7 +40,7 @@ def test_request_token_null_params(self): post_time = self.server_time() assert token_details.token is not None, "Expected token" assert token_details.issued + 300 >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time + 300, "Unexpected issued time" + assert token_details.issued <= post_time + 500, "Unexpected issued time" assert self.permit_all == str(token_details.capability), "Unexpected capability" def test_request_token_explicit_timestamp(self): @@ -123,8 +123,8 @@ def test_token_generation_with_invalid_ttl(self): def test_token_generation_with_local_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: self.ably.auth.request_token() assert local_time.called assert not server_time.called @@ -132,8 +132,8 @@ def test_token_generation_with_local_time(self): # RSA10k def test_token_generation_with_server_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: self.ably.auth.request_token(query_time=True) assert local_time.call_count == 1 assert server_time.call_count == 1 @@ -185,8 +185,8 @@ def test_key_name_and_secret_are_required(self): @dont_vary_protocol def test_with_local_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: self.ably.auth.create_token_request( key_name=self.key_name, key_secret=self.key_secret, query_time=False) assert local_time.called @@ -196,8 +196,8 @@ def test_with_local_time(self): @dont_vary_protocol def test_with_server_time(self): timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.Auth._timestamp', wraps=timestamp) as local_time: + with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ + patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: self.ably.auth.create_token_request( key_name=self.key_name, key_secret=self.key_secret, query_time=True) assert local_time.call_count == 1 @@ -317,7 +317,7 @@ def auth_callback(token_params): @dont_vary_protocol def test_hmac(self): - ably = AblyRest(key_name='a_key_name', key_secret='a_secret') + ably = AblyRestSync(key_name='a_key_name', key_secret='a_secret') token_params = { 'ttl': 1000, 'nonce': 'abcde100', @@ -332,7 +332,7 @@ def test_hmac(self): # AO2g @dont_vary_protocol def test_query_server_time(self): - with patch('ably.sync.rest.rest.AblyRest.time', wraps=self.ably.time) as server_time: + with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time: self.ably.auth.create_token_request( key_name=self.key_name, key_secret=self.key_secret, query_time=True) assert server_time.call_count == 1 diff --git a/test/ably/sync/testapp.py b/test/ably/sync/testapp.py index 54c0af02..fd3e4f2d 100644 --- a/test/ably/sync/testapp.py +++ b/test/ably/sync/testapp.py @@ -2,7 +2,7 @@ import os import logging -from ably.sync.rest.rest import AblyRest +from ably.sync.rest.rest import AblyRestSync from ably.sync.types.capability import Capability from ably.sync.types.options import Options from ably.sync.util.exceptions import AblyException @@ -28,7 +28,7 @@ tls_port = 8081 -ably = AblyRest(token='not_a_real_token', +ably = AblyRestSync(token='not_a_real_token', port=port, tls_port=tls_port, tls=tls, environment=environment, use_binary_protocol=False) @@ -74,7 +74,7 @@ def get_ably_rest(**kw): test_vars = TestApp.get_test_vars() options = TestApp.get_options(test_vars, **kw) options.update(kw) - return AblyRest(**options) + return AblyRestSync(**options) @staticmethod def get_ably_realtime(**kw): diff --git a/test/ably/sync/utils.py b/test/ably/sync/utils.py index 7bc4ebd7..a45a7b39 100644 --- a/test/ably/sync/utils.py +++ b/test/ably/sync/utils.py @@ -15,7 +15,7 @@ import respx from httpx import Response -from ably.sync.http.http import Http +from ably.sync.http.http import HttpSync class BaseTestCase(unittest.TestCase): @@ -71,14 +71,14 @@ def test_something(self): responses = [] def patch(): - original = Http.make_request + original = HttpSync.make_request def fake_make_request(self, *args, **kwargs): response = original(self, *args, **kwargs) responses.append(response) return response - patcher = mock.patch.object(Http, 'make_request', fake_make_request) + patcher = mock.patch.object(HttpSync, 'make_request', fake_make_request) patcher.start() return patcher From d8d7b36db98e9abc118d0d9a6c496fa507f5ee6f Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 16:55:13 +0530 Subject: [PATCH 34/52] Fixed indentation issues caused by class rename in unasync file --- unasync.py | 52 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/unasync.py b/unasync.py index 302fa55c..405a2252 100644 --- a/unasync.py +++ b/unasync.py @@ -75,29 +75,45 @@ def _unasync_tokens(self, tokens: list): new_tokens = [] token_counter = 0 async_await_block_started = False + async_await_char_diff = -6 # (len("async") or len("await") is 6) async_await_offset = 0 + + renamed_class_call_started = False + renamed_class_char_diff = 0 + renamed_class_offset = 0 + while token_counter < len(tokens): token = tokens[token_counter] - if async_await_block_started: + if async_await_block_started or renamed_class_call_started: # Fix indentation issues for async/await fn definition/call if token.src == '\n': new_tokens.append(token) token_counter = token_counter + 1 next_newline_token = tokens[token_counter] - if (len(next_newline_token.src) >= 6 and + new_tab_src = next_newline_token.src + + if (renamed_class_call_started and + tokens[token_counter + 1].utf8_byte_offset >= renamed_class_offset): + if renamed_class_char_diff < 0: + new_tab_src = new_tab_src[:renamed_class_char_diff] + else: + new_tab_src = new_tab_src + renamed_class_char_diff * " " + + if (async_await_block_started and len(next_newline_token.src) >= 6 and tokens[token_counter + 1].utf8_byte_offset >= async_await_offset + 6): - new_tab_indentation = next_newline_token.src[:-6] # remove last 6 white spaces - next_newline_token = next_newline_token._replace(src=new_tab_indentation) - new_tokens.append(next_newline_token) - else: - new_tokens.append(next_newline_token) + new_tab_src = new_tab_src[:async_await_char_diff] # remove last 6 white spaces + + next_newline_token = next_newline_token._replace(src=new_tab_src) + new_tokens.append(next_newline_token) token_counter = token_counter + 1 continue if token.src == ')': async_await_block_started = False async_await_offset = 0 + renamed_class_call_started = False + renamed_class_char_diff = 0 if token.src in ["async", "await"]: # When removing async or await, we want to skip the following whitespace @@ -120,7 +136,18 @@ def _unasync_tokens(self, tokens: list): token_counter = self._replace_import(tokens, token_counter, new_tokens) continue else: - token = token._replace(src=self._unasync_name(token.src)) + token_new_src = self._unasync_name(token.src) + if token.src == token_new_src: + token_new_src = self._class_rename(token.src) + if token.src != token_new_src: + renamed_class_offset = token.utf8_byte_offset + renamed_class_char_diff = len(token_new_src) - len(token.src) + for i in range(token_counter, token_counter + 6): + if tokens[i].src == '(': + renamed_class_call_started = True + break + + token = token._replace(src=token_new_src) elif token.name == "STRING": src_token = token.src.replace("'", "") if _STRING_REPLACE.get(src_token) is not None: @@ -156,7 +183,7 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): if key in full_lib_name: updated_lib_name = full_lib_name.replace(key, value) for lib_name_part in updated_lib_name.split("."): - lib_name_part = self._unasync_name(lib_name_part) + lib_name_part = self._class_rename(lib_name_part) new_tokens.append(tokenize_rt.Token("NAME", lib_name_part)) new_tokens.append(tokenize_rt.Token("OP", ".")) new_tokens.pop() @@ -165,11 +192,14 @@ def _replace_import(self, tokens, token_counter, new_tokens: list): lib_name_counter = token_counter + 2 return lib_name_counter + def _class_rename(self, name): + if name in _CLASS_RENAME: + return _CLASS_RENAME[name] + return name + def _unasync_name(self, name): if name in self.token_replacements: return self.token_replacements[name] - if name in _CLASS_RENAME: - return _CLASS_RENAME[name] return name From 68e4e1b34451a3ff31664483ab79d23ac05e9e26 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 16:55:58 +0530 Subject: [PATCH 35/52] Generated sync files to resolve indentation issues --- ably/sync/rest/push.py | 4 ++-- test/ably/sync/rest/sync_restauth_test.py | 2 +- test/ably/sync/rest/sync_restinit_test.py | 2 +- test/ably/sync/rest/sync_restrequest_test.py | 8 ++++---- test/ably/sync/testapp.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ably/sync/rest/push.py b/ably/sync/rest/push.py index 34a7ddff..3bb4de40 100644 --- a/ably/sync/rest/push.py +++ b/ably/sync/rest/push.py @@ -142,7 +142,7 @@ def list(self, **params): """ path = '/push/channelSubscriptions' + format_params(params) return PaginatedResultSync.paginated_query(self.ably.http, url=path, - response_processor=channel_subscriptions_response_processor) + response_processor=channel_subscriptions_response_processor) def list_channels(self, **params): """Returns a PaginatedResult object with the list of @@ -153,7 +153,7 @@ def list_channels(self, **params): """ path = '/push/channels' + format_params(params) return PaginatedResultSync.paginated_query(self.ably.http, url=path, - response_processor=channels_response_processor) + response_processor=channels_response_processor) def save(self, subscription: dict): """Creates or updates the subscription. Returns a diff --git a/test/ably/sync/rest/sync_restauth_test.py b/test/ably/sync/rest/sync_restauth_test.py index 1a2db77d..e4f3560b 100644 --- a/test/ably/sync/rest/sync_restauth_test.py +++ b/test/ably/sync/rest/sync_restauth_test.py @@ -160,7 +160,7 @@ def test_with_auth_params(self): def test_with_default_token_params(self): ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], - default_token_params={'ttl': 12345}) + default_token_params={'ttl': 12345}) assert ably.auth.auth_options.default_token_params == {'ttl': 12345} diff --git a/test/ably/sync/rest/sync_restinit_test.py b/test/ably/sync/rest/sync_restinit_test.py index 327076b9..99837890 100644 --- a/test/ably/sync/rest/sync_restinit_test.py +++ b/test/ably/sync/rest/sync_restinit_test.py @@ -86,7 +86,7 @@ def test_rest_host_and_environment(self): # both, as per #TO3k2 with pytest.raises(ValueError): ably = AblyRestSync(token='foo', rest_host="some.other.host", - environment="some.other.environment") + environment="some.other.environment") # RSC15 @dont_vary_protocol diff --git a/test/ably/sync/rest/sync_restrequest_test.py b/test/ably/sync/rest/sync_restrequest_test.py index 9beb3c11..8c090ac7 100644 --- a/test/ably/sync/rest/sync_restrequest_test.py +++ b/test/ably/sync/rest/sync_restrequest_test.py @@ -118,10 +118,10 @@ def test_timeout(self): # Bad host, no Fallback ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], - rest_host='some.other.host', - port=self.test_vars["port"], - tls_port=self.test_vars["tls_port"], - tls=self.test_vars["tls"]) + rest_host='some.other.host', + port=self.test_vars["port"], + tls_port=self.test_vars["tls_port"], + tls=self.test_vars["tls"]) with pytest.raises(httpx.ConnectError): ably.request('GET', '/time', version=Defaults.protocol_version) ably.close() diff --git a/test/ably/sync/testapp.py b/test/ably/sync/testapp.py index fd3e4f2d..0947296f 100644 --- a/test/ably/sync/testapp.py +++ b/test/ably/sync/testapp.py @@ -29,9 +29,9 @@ ably = AblyRestSync(token='not_a_real_token', - port=port, tls_port=tls_port, tls=tls, - environment=environment, - use_binary_protocol=False) + port=port, tls_port=tls_port, tls=tls, + environment=environment, + use_binary_protocol=False) class TestApp: From a3401634fcf44bbf0143c470a1e6fe79529afe8f Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 16:58:46 +0530 Subject: [PATCH 36/52] Fixed indentation issues as per flake8 --- unasync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unasync.py b/unasync.py index 405a2252..7958682e 100644 --- a/unasync.py +++ b/unasync.py @@ -75,7 +75,7 @@ def _unasync_tokens(self, tokens: list): new_tokens = [] token_counter = 0 async_await_block_started = False - async_await_char_diff = -6 # (len("async") or len("await") is 6) + async_await_char_diff = -6 # (len("async") or len("await") is 6) async_await_offset = 0 renamed_class_call_started = False @@ -102,7 +102,7 @@ def _unasync_tokens(self, tokens: list): if (async_await_block_started and len(next_newline_token.src) >= 6 and tokens[token_counter + 1].utf8_byte_offset >= async_await_offset + 6): - new_tab_src = new_tab_src[:async_await_char_diff] # remove last 6 white spaces + new_tab_src = new_tab_src[:async_await_char_diff] # remove last 6 white spaces next_newline_token = next_newline_token._replace(src=new_tab_src) new_tokens.append(next_newline_token) From 0105b2bfc2f240ba4b7d7139838a3941a06f6408 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 17:05:36 +0530 Subject: [PATCH 37/52] Added extra step to generate sync rest code and tests to github workflow --- .github/workflows/check.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 7112f197..ddf6a644 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -35,5 +35,7 @@ jobs: run: poetry install -E crypto - name: Lint with flake8 run: poetry run flake8 + - name: Generate rest sync code and tests + run: poetry run python unasync.py - name: Test with pytest run: poetry run pytest From 5da20b1b864548990a17dec88579b42fec1fee99 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 17:07:02 +0530 Subject: [PATCH 38/52] Removed all generated sync files --- ably/sync/__init__.py | 18 - ably/sync/http/__init__.py | 0 ably/sync/http/http.py | 301 -------- ably/sync/http/httputils.py | 55 -- ably/sync/http/paginatedresult.py | 134 ---- ably/sync/realtime/__init__.py | 0 ably/sync/realtime/connection.py | 119 ---- ably/sync/realtime/connectionmanager.py | 524 -------------- ably/sync/realtime/realtime.py | 140 ---- ably/sync/realtime/realtime_channel.py | 553 --------------- ably/sync/rest/__init__.py | 0 ably/sync/rest/auth.py | 425 ------------ ably/sync/rest/channel.py | 229 ------ ably/sync/rest/push.py | 189 ----- ably/sync/rest/rest.py | 148 ---- ably/sync/transport/__init__.py | 0 ably/sync/transport/defaults.py | 63 -- ably/sync/transport/websockettransport.py | 219 ------ ably/sync/types/__init__.py | 0 ably/sync/types/authoptions.py | 157 ----- ably/sync/types/capability.py | 82 --- ably/sync/types/channeldetails.py | 116 ---- ably/sync/types/channelstate.py | 22 - ably/sync/types/channelsubscription.py | 70 -- ably/sync/types/connectiondetails.py | 20 - ably/sync/types/connectionerrors.py | 30 - ably/sync/types/connectionstate.py | 36 - ably/sync/types/device.py | 116 ---- ably/sync/types/flags.py | 19 - ably/sync/types/message.py | 233 ------- ably/sync/types/mixins.py | 75 -- ably/sync/types/options.py | 330 --------- ably/sync/types/presence.py | 174 ----- ably/sync/types/stats.py | 67 -- ably/sync/types/tokendetails.py | 97 --- ably/sync/types/tokenrequest.py | 107 --- ably/sync/types/typedbuffer.py | 104 --- ably/sync/util/__init__.py | 0 ably/sync/util/case.py | 18 - ably/sync/util/crypto.py | 179 ----- ably/sync/util/eventemitter.py | 185 ----- ably/sync/util/exceptions.py | 92 --- ably/sync/util/helper.py | 42 -- ably/sync/util/nocrypto.py | 9 - test/ably/sync/rest/sync_encoders_test.py | 456 ------------ test/ably/sync/rest/sync_restauth_test.py | 652 ------------------ .../sync/rest/sync_restcapability_test.py | 242 ------- .../sync/rest/sync_restchannelhistory_test.py | 332 --------- .../sync/rest/sync_restchannelpublish_test.py | 568 --------------- test/ably/sync/rest/sync_restchannels_test.py | 91 --- .../sync/rest/sync_restchannelstatus_test.py | 47 -- test/ably/sync/rest/sync_restcrypto_test.py | 264 ------- test/ably/sync/rest/sync_resthttp_test.py | 229 ------ test/ably/sync/rest/sync_restinit_test.py | 227 ------ .../rest/sync_restpaginatedresult_test.py | 91 --- test/ably/sync/rest/sync_restpresence_test.py | 213 ------ test/ably/sync/rest/sync_restpush_test.py | 398 ----------- test/ably/sync/rest/sync_restrequest_test.py | 132 ---- test/ably/sync/rest/sync_reststats_test.py | 310 --------- test/ably/sync/rest/sync_resttime_test.py | 43 -- test/ably/sync/rest/sync_resttoken_test.py | 342 --------- test/ably/sync/testapp.py | 115 --- test/ably/sync/utils.py | 180 ----- 63 files changed, 10429 deletions(-) delete mode 100644 ably/sync/__init__.py delete mode 100644 ably/sync/http/__init__.py delete mode 100644 ably/sync/http/http.py delete mode 100644 ably/sync/http/httputils.py delete mode 100644 ably/sync/http/paginatedresult.py delete mode 100644 ably/sync/realtime/__init__.py delete mode 100644 ably/sync/realtime/connection.py delete mode 100644 ably/sync/realtime/connectionmanager.py delete mode 100644 ably/sync/realtime/realtime.py delete mode 100644 ably/sync/realtime/realtime_channel.py delete mode 100644 ably/sync/rest/__init__.py delete mode 100644 ably/sync/rest/auth.py delete mode 100644 ably/sync/rest/channel.py delete mode 100644 ably/sync/rest/push.py delete mode 100644 ably/sync/rest/rest.py delete mode 100644 ably/sync/transport/__init__.py delete mode 100644 ably/sync/transport/defaults.py delete mode 100644 ably/sync/transport/websockettransport.py delete mode 100644 ably/sync/types/__init__.py delete mode 100644 ably/sync/types/authoptions.py delete mode 100644 ably/sync/types/capability.py delete mode 100644 ably/sync/types/channeldetails.py delete mode 100644 ably/sync/types/channelstate.py delete mode 100644 ably/sync/types/channelsubscription.py delete mode 100644 ably/sync/types/connectiondetails.py delete mode 100644 ably/sync/types/connectionerrors.py delete mode 100644 ably/sync/types/connectionstate.py delete mode 100644 ably/sync/types/device.py delete mode 100644 ably/sync/types/flags.py delete mode 100644 ably/sync/types/message.py delete mode 100644 ably/sync/types/mixins.py delete mode 100644 ably/sync/types/options.py delete mode 100644 ably/sync/types/presence.py delete mode 100644 ably/sync/types/stats.py delete mode 100644 ably/sync/types/tokendetails.py delete mode 100644 ably/sync/types/tokenrequest.py delete mode 100644 ably/sync/types/typedbuffer.py delete mode 100644 ably/sync/util/__init__.py delete mode 100644 ably/sync/util/case.py delete mode 100644 ably/sync/util/crypto.py delete mode 100644 ably/sync/util/eventemitter.py delete mode 100644 ably/sync/util/exceptions.py delete mode 100644 ably/sync/util/helper.py delete mode 100644 ably/sync/util/nocrypto.py delete mode 100644 test/ably/sync/rest/sync_encoders_test.py delete mode 100644 test/ably/sync/rest/sync_restauth_test.py delete mode 100644 test/ably/sync/rest/sync_restcapability_test.py delete mode 100644 test/ably/sync/rest/sync_restchannelhistory_test.py delete mode 100644 test/ably/sync/rest/sync_restchannelpublish_test.py delete mode 100644 test/ably/sync/rest/sync_restchannels_test.py delete mode 100644 test/ably/sync/rest/sync_restchannelstatus_test.py delete mode 100644 test/ably/sync/rest/sync_restcrypto_test.py delete mode 100644 test/ably/sync/rest/sync_resthttp_test.py delete mode 100644 test/ably/sync/rest/sync_restinit_test.py delete mode 100644 test/ably/sync/rest/sync_restpaginatedresult_test.py delete mode 100644 test/ably/sync/rest/sync_restpresence_test.py delete mode 100644 test/ably/sync/rest/sync_restpush_test.py delete mode 100644 test/ably/sync/rest/sync_restrequest_test.py delete mode 100644 test/ably/sync/rest/sync_reststats_test.py delete mode 100644 test/ably/sync/rest/sync_resttime_test.py delete mode 100644 test/ably/sync/rest/sync_resttoken_test.py delete mode 100644 test/ably/sync/testapp.py delete mode 100644 test/ably/sync/utils.py diff --git a/ably/sync/__init__.py b/ably/sync/__init__.py deleted file mode 100644 index 210c52f5..00000000 --- a/ably/sync/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from ably.sync.rest.rest import AblyRestSync -from ably.sync.realtime.realtime import AblyRealtime -from ably.sync.rest.auth import AuthSync -from ably.sync.rest.push import PushSync -from ably.sync.types.capability import Capability -from ably.sync.types.channelsubscription import PushChannelSubscription -from ably.sync.types.device import DeviceDetails -from ably.sync.types.options import Options -from ably.sync.util.crypto import CipherParams -from ably.sync.util.exceptions import AblyException, AblyAuthException, IncompatibleClientIdException - -import logging - -logger = logging.getLogger(__name__) -logger.addHandler(logging.NullHandler()) - -api_version = '3' -lib_version = '2.0.2' diff --git a/ably/sync/http/__init__.py b/ably/sync/http/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ably/sync/http/http.py b/ably/sync/http/http.py deleted file mode 100644 index 51d0bb88..00000000 --- a/ably/sync/http/http.py +++ /dev/null @@ -1,301 +0,0 @@ -import functools -import logging -import time -import json -from urllib.parse import urljoin - -import httpx -import msgpack - -from ably.sync.rest.auth import AuthSync -from ably.sync.http.httputils import HttpUtils -from ably.sync.transport.defaults import Defaults -from ably.sync.util.exceptions import AblyException -from ably.sync.util.helper import is_token_error - -log = logging.getLogger(__name__) - - -def reauth_if_expired(func): - @functools.wraps(func) - def wrapper(rest, *args, **kwargs): - if kwargs.get("skip_auth"): - return func(rest, *args, **kwargs) - - # RSA4b1 Detect expired token to avoid round-trip request - auth = rest.auth - token_details = auth.token_details - if token_details and auth.time_offset is not None and auth.token_details_has_expired(): - auth.authorize() - retried = True - else: - retried = False - - try: - return func(rest, *args, **kwargs) - except AblyException as e: - if is_token_error(e) and not retried: - auth.authorize() - return func(rest, *args, **kwargs) - - raise e - - return wrapper - - -class Request: - def __init__(self, method='GET', url='/', version=None, headers=None, body=None, - skip_auth=False, raise_on_error=True): - self.__method = method - self.__headers = headers or {} - self.__body = body - self.__skip_auth = skip_auth - self.__url = url - self.__version = version - self.raise_on_error = raise_on_error - - def with_relative_url(self, relative_url): - url = urljoin(self.url, relative_url) - return Request(self.method, url, self.version, self.headers, self.body, - self.skip_auth, self.raise_on_error) - - @property - def method(self): - return self.__method - - @property - def url(self): - return self.__url - - @property - def headers(self): - return self.__headers - - @property - def body(self): - return self.__body - - @property - def skip_auth(self): - return self.__skip_auth - - @property - def version(self): - return self.__version - - -class Response: - """ - Composition for httpx.Response with delegation - """ - - def __init__(self, response): - self.__response = response - - def to_native(self): - content = self.__response.content - if not content: - return None - - content_type = self.__response.headers.get('content-type') - if isinstance(content_type, str): - if content_type.startswith('application/x-msgpack'): - return msgpack.unpackb(content) - elif content_type.startswith('application/json'): - return self.__response.json() - - raise ValueError("Unsupported content type") - - @property - def response(self): - return self.__response - - def __getattr__(self, attr): - return getattr(self.__response, attr) - - -class HttpSync: - CONNECTION_RETRY_DEFAULTS = { - 'http_open_timeout': 4, - 'http_request_timeout': 10, - 'http_max_retry_duration': 15, - } - - def __init__(self, ably, options): - options = options or {} - self.__ably = ably - self.__options = options - self.__auth = None - # Cached fallback host (RSC15f) - self.__host = None - self.__host_expires = None - self.__client = httpx.Client(http2=True) - - def close(self): - self.__client.close() - - def dump_body(self, body): - if self.options.use_binary_protocol: - return msgpack.packb(body, use_bin_type=False) - else: - return json.dumps(body, separators=(',', ':')) - - def get_rest_hosts(self): - hosts = self.options.get_rest_hosts() - host = self.__host or self.options.fallback_realtime_host - if host is None: - return hosts - - if time.time() > self.__host_expires: - self.__host = None - self.__host_expires = None - return hosts - - hosts = list(hosts) - hosts.remove(host) - hosts.insert(0, host) - return hosts - - @reauth_if_expired - def make_request(self, method, path, version=None, headers=None, body=None, - skip_auth=False, timeout=None, raise_on_error=True): - - if body is not None and type(body) not in (bytes, str): - body = self.dump_body(body) - - if body: - all_headers = HttpUtils.default_post_headers(self.options.use_binary_protocol, version=version) - else: - all_headers = HttpUtils.default_get_headers(self.options.use_binary_protocol, version=version) - - params = HttpUtils.get_query_params(self.options) - - if not skip_auth: - if self.auth.auth_mechanism == AuthSync.Method.BASIC and self.preferred_scheme.lower() == 'http': - raise AblyException( - "Cannot use Basic Auth over non-TLS connections", - 401, - 40103) - auth_headers = self.auth._get_auth_headers() - all_headers.update(auth_headers) - if headers: - all_headers.update(headers) - - timeout = (self.http_open_timeout, self.http_request_timeout) - http_max_retry_duration = self.http_max_retry_duration - requested_at = time.time() - - hosts = self.get_rest_hosts() - for retry_count, host in enumerate(hosts): - base_url = "%s://%s:%d" % (self.preferred_scheme, - host, - self.preferred_port) - url = urljoin(base_url, path) - - request = self.__client.build_request( - method=method, - url=url, - content=body, - params=params, - headers=all_headers, - timeout=timeout, - ) - try: - response = self.__client.send(request) - except Exception as e: - # if last try or cumulative timeout is done, throw exception up - time_passed = time.time() - requested_at - if retry_count == len(hosts) - 1 or time_passed > http_max_retry_duration: - raise e - else: - try: - if raise_on_error: - AblyException.raise_for_response(response) - - # Keep fallback host for later (RSC15f) - if retry_count > 0 and host != self.options.get_rest_host(): - self.__host = host - self.__host_expires = time.time() + (self.options.fallback_retry_timeout / 1000.0) - - return Response(response) - except AblyException as e: - if not e.is_server_error: - raise e - - # if last try or cumulative timeout is done, throw exception up - time_passed = time.time() - requested_at - if retry_count == len(hosts) - 1 or time_passed > http_max_retry_duration: - raise e - - def delete(self, url, headers=None, skip_auth=False, timeout=None): - result = self.make_request('DELETE', url, headers=headers, - skip_auth=skip_auth, timeout=timeout) - return result - - def get(self, url, headers=None, skip_auth=False, timeout=None): - result = self.make_request('GET', url, headers=headers, - skip_auth=skip_auth, timeout=timeout) - return result - - def patch(self, url, headers=None, body=None, skip_auth=False, timeout=None): - result = self.make_request('PATCH', url, headers=headers, body=body, - skip_auth=skip_auth, timeout=timeout) - return result - - def post(self, url, headers=None, body=None, skip_auth=False, timeout=None): - result = self.make_request('POST', url, headers=headers, body=body, - skip_auth=skip_auth, timeout=timeout) - return result - - def put(self, url, headers=None, body=None, skip_auth=False, timeout=None): - result = self.make_request('PUT', url, headers=headers, body=body, - skip_auth=skip_auth, timeout=timeout) - return result - - @property - def auth(self): - return self.__auth - - @auth.setter - def auth(self, value): - self.__auth = value - - @property - def options(self): - return self.__options - - @property - def preferred_host(self): - return self.options.get_rest_host() - - @property - def preferred_port(self): - return Defaults.get_port(self.options) - - @property - def preferred_scheme(self): - return Defaults.get_scheme(self.options) - - @property - def http_open_timeout(self): - if self.options.http_open_timeout is not None: - return self.options.http_open_timeout - return self.CONNECTION_RETRY_DEFAULTS['http_open_timeout'] - - @property - def http_request_timeout(self): - if self.options.http_request_timeout is not None: - return self.options.http_request_timeout - return self.CONNECTION_RETRY_DEFAULTS['http_request_timeout'] - - @property - def http_max_retry_count(self): - if self.options.http_max_retry_count is not None: - return self.options.http_max_retry_count - return self.CONNECTION_RETRY_DEFAULTS['http_max_retry_count'] - - @property - def http_max_retry_duration(self): - if self.options.http_max_retry_duration is not None: - return self.options.http_max_retry_duration - return self.CONNECTION_RETRY_DEFAULTS['http_max_retry_duration'] diff --git a/ably/sync/http/httputils.py b/ably/sync/http/httputils.py deleted file mode 100644 index b55ae75c..00000000 --- a/ably/sync/http/httputils.py +++ /dev/null @@ -1,55 +0,0 @@ -import base64 -import os -import platform - -import ably - - -class HttpUtils: - default_format = "json" - - mime_types = { - "json": "application/json", - "xml": "application/xml", - "html": "text/html", - "binary": "application/x-msgpack", - } - - @staticmethod - def default_get_headers(binary=False, version=None): - headers = HttpUtils.default_headers(version=version) - if binary: - headers["Accept"] = HttpUtils.mime_types['binary'] - else: - headers["Accept"] = HttpUtils.mime_types['json'] - return headers - - @staticmethod - def default_post_headers(binary=False, version=None): - headers = HttpUtils.default_get_headers(binary=binary, version=version) - headers["Content-Type"] = headers["Accept"] - return headers - - @staticmethod - def get_host_header(host): - return { - 'Host': host, - } - - @staticmethod - def default_headers(version=None): - if version is None: - version = ably.api_version - return { - "X-Ably-Version": version, - "Ably-Agent": 'ably-python/%s python/%s' % (ably.lib_version, platform.python_version()) - } - - @staticmethod - def get_query_params(options): - params = {} - - if options.add_request_ids: - params['request_id'] = base64.urlsafe_b64encode(os.urandom(12)).decode('ascii') - - return params diff --git a/ably/sync/http/paginatedresult.py b/ably/sync/http/paginatedresult.py deleted file mode 100644 index 663baad9..00000000 --- a/ably/sync/http/paginatedresult.py +++ /dev/null @@ -1,134 +0,0 @@ -import calendar -import logging -from urllib.parse import urlencode - -from ably.sync.http.http import Request -from ably.sync.util import case - -log = logging.getLogger(__name__) - - -def format_time_param(t): - try: - return '%d' % (calendar.timegm(t.utctimetuple()) * 1000) - except Exception: - return str(t) - - -def format_params(params=None, direction=None, start=None, end=None, limit=None, **kw): - if params is None: - params = {} - - for key, value in kw.items(): - if value is not None: - key = case.snake_to_camel(key) - params[key] = value - - if direction: - params['direction'] = str(direction) - if start: - params['start'] = format_time_param(start) - if end: - params['end'] = format_time_param(end) - if limit: - if limit > 1000: - raise ValueError("The maximum allowed limit is 1000") - params['limit'] = '%d' % limit - - if 'start' in params and 'end' in params and params['start'] > params['end']: - raise ValueError("'end' parameter has to be greater than or equal to 'start'") - - return '?' + urlencode(params) if params else '' - - -class PaginatedResultSync: - def __init__(self, http, items, content_type, rel_first, rel_next, - response_processor, response): - self.__http = http - self.__items = items - self.__content_type = content_type - self.__rel_first = rel_first - self.__rel_next = rel_next - self.__response_processor = response_processor - self.response = response - - @property - def items(self): - return self.__items - - def has_first(self): - return self.__rel_first is not None - - def has_next(self): - return self.__rel_next is not None - - def is_last(self): - return not self.has_next() - - def first(self): - return self.__get_rel(self.__rel_first) if self.__rel_first else None - - def next(self): - return self.__get_rel(self.__rel_next) if self.__rel_next else None - - def __get_rel(self, rel_req): - if rel_req is None: - return None - return self.paginated_query_with_request(self.__http, rel_req, self.__response_processor) - - @classmethod - def paginated_query(cls, http, method='GET', url='/', version=None, body=None, - headers=None, response_processor=None, - raise_on_error=True): - headers = headers or {} - req = Request(method, url, version=version, body=body, headers=headers, skip_auth=False, - raise_on_error=raise_on_error) - return cls.paginated_query_with_request(http, req, response_processor) - - @classmethod - def paginated_query_with_request(cls, http, request, response_processor, - raise_on_error=True): - response = http.make_request( - request.method, request.url, version=request.version, - headers=request.headers, body=request.body, - skip_auth=request.skip_auth, raise_on_error=request.raise_on_error) - - items = response_processor(response) - - content_type = response.headers['Content-Type'] - links = response.links - if 'first' in links: - first_rel_request = request.with_relative_url(links['first']['url']) - else: - first_rel_request = None - - if 'next' in links: - next_rel_request = request.with_relative_url(links['next']['url']) - else: - next_rel_request = None - - return cls(http, items, content_type, first_rel_request, - next_rel_request, response_processor, response) - - -class HttpPaginatedResponseSync(PaginatedResultSync): - @property - def status_code(self): - return self.response.status_code - - @property - def success(self): - status_code = self.status_code - return 200 <= status_code < 300 - - @property - def error_code(self): - return self.response.headers.get('X-Ably-Errorcode') - - @property - def error_message(self): - return self.response.headers.get('X-Ably-Errormessage') - - @property - def headers(self): - return list(self.response.headers.items()) diff --git a/ably/sync/realtime/__init__.py b/ably/sync/realtime/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ably/sync/realtime/connection.py b/ably/sync/realtime/connection.py deleted file mode 100644 index 9cf046ff..00000000 --- a/ably/sync/realtime/connection.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations -import functools -import logging -from ably.sync.realtime.connectionmanager import ConnectionManager -from ably.sync.types.connectiondetails import ConnectionDetails -from ably.sync.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange -from ably.sync.util.eventemitter import EventEmitter -from ably.sync.util.exceptions import AblyException -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from ably.sync.realtime.realtime import AblyRealtime - -log = logging.getLogger(__name__) - - -class Connection(EventEmitter): # RTN4 - """Ably Realtime Connection - - Enables the management of a connection to Ably - - Attributes - ---------- - state: str - Connection state - error_reason: ErrorInfo - An ErrorInfo object describing the last error which occurred on the channel, if any. - - - Methods - ------- - connect() - Establishes a realtime connection - close() - Closes a realtime connection - ping() - Pings a realtime connection - """ - - def __init__(self, realtime: AblyRealtime): - self.__realtime = realtime - self.__error_reason: Optional[AblyException] = None - self.__state = ConnectionState.CONNECTING if realtime.options.auto_connect else ConnectionState.INITIALIZED - self.__connection_manager = ConnectionManager(self.__realtime, self.state) - self.__connection_manager.on('connectionstate', self._on_state_update) # RTN4a - self.__connection_manager.on('update', self._on_connection_update) # RTN4h - super().__init__() - - # RTN11 - def connect(self) -> None: - """Establishes a realtime connection. - - Causes the connection to open, entering the connecting state - """ - self.__error_reason = None - self.connection_manager.request_state(ConnectionState.CONNECTING) - - def close(self) -> None: - """Causes the connection to close, entering the closing state. - - Once closed, the library will not attempt to re-establish the - connection without an explicit call to connect() - """ - self.connection_manager.request_state(ConnectionState.CLOSING) - self.once_async(ConnectionState.CLOSED) - - # RTN13 - def ping(self) -> float: - """Send a ping to the realtime connection - - When connected, sends a heartbeat ping to the Ably server and executes - the callback with any error and the response time in milliseconds when - a heartbeat ping request is echoed from the server. - - Raises - ------ - AblyException - If ping request cannot be sent due to invalid state - - Returns - ------- - float - The response time in milliseconds - """ - return self.__connection_manager.ping() - - def _on_state_update(self, state_change: ConnectionStateChange) -> None: - log.info(f'Connection state changing from {self.state} to {state_change.current}') - self.__state = state_change.current - if state_change.reason is not None: - self.__error_reason = state_change.reason - self.__realtime.options.loop.call_soon(functools.partial(self._emit, state_change.current, state_change)) - - def _on_connection_update(self, state_change: ConnectionStateChange) -> None: - self.__realtime.options.loop.call_soon(functools.partial(self._emit, ConnectionEvent.UPDATE, state_change)) - - # RTN4d - @property - def state(self) -> ConnectionState: - """The current connection state of the connection""" - return self.__state - - # RTN25 - @property - def error_reason(self) -> Optional[AblyException]: - """An object describing the last error which occurred on the channel, if any.""" - return self.__error_reason - - @state.setter - def state(self, value: ConnectionState) -> None: - self.__state = value - - @property - def connection_manager(self) -> ConnectionManager: - return self.__connection_manager - - @property - def connection_details(self) -> Optional[ConnectionDetails]: - return self.__connection_manager.connection_details diff --git a/ably/sync/realtime/connectionmanager.py b/ably/sync/realtime/connectionmanager.py deleted file mode 100644 index 7e5fd820..00000000 --- a/ably/sync/realtime/connectionmanager.py +++ /dev/null @@ -1,524 +0,0 @@ -from __future__ import annotations -import logging -import asyncio -import httpx -from ably.sync.transport.websockettransport import WebSocketTransport, ProtocolMessageAction -from ably.sync.transport.defaults import Defaults -from ably.sync.types.connectionerrors import ConnectionErrors -from ably.sync.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange -from ably.sync.types.tokendetails import TokenDetails -from ably.sync.util.exceptions import AblyException, IncompatibleClientIdException -from ably.sync.util.eventemitter import EventEmitter -from datetime import datetime -from ably.sync.util.helper import get_random_id, Timer, is_token_error -from typing import Optional, TYPE_CHECKING -from ably.sync.types.connectiondetails import ConnectionDetails -from queue import Queue - -if TYPE_CHECKING: - from ably.sync.realtime.realtime import AblyRealtime - -log = logging.getLogger(__name__) - - -class ConnectionManager(EventEmitter): - def __init__(self, realtime: AblyRealtime, initial_state): - self.options = realtime.options - self.__ably = realtime - self.__state: ConnectionState = initial_state - self.__ping_future: Optional[asyncio.Future] = None - self.__timeout_in_secs: float = self.options.realtime_request_timeout / 1000 - self.transport: Optional[WebSocketTransport] = None - self.__connection_details: Optional[ConnectionDetails] = None - self.connection_id: Optional[str] = None - self.__fail_state = ConnectionState.DISCONNECTED - self.transition_timer: Optional[Timer] = None - self.suspend_timer: Optional[Timer] = None - self.retry_timer: Optional[Timer] = None - self.connect_base_task: Optional[asyncio.Task] = None - self.disconnect_transport_task: Optional[asyncio.Task] = None - self.__fallback_hosts: list[str] = self.options.get_fallback_realtime_hosts() - self.queued_messages: Queue = Queue() - self.__error_reason: Optional[AblyException] = None - super().__init__() - - def enact_state_change(self, state: ConnectionState, reason: Optional[AblyException] = None) -> None: - current_state = self.__state - log.debug(f'ConnectionManager.enact_state_change(): {current_state} -> {state}; reason = {reason}') - self.__state = state - if reason: - self.__error_reason = reason - self._emit('connectionstate', ConnectionStateChange(current_state, state, state, reason)) - - def check_connection(self) -> bool: - try: - response = httpx.get(self.options.connectivity_check_url) - return 200 <= response.status_code < 300 and \ - (self.options.connectivity_check_url != Defaults.connectivity_check_url or "yes" in response.text) - except httpx.HTTPError: - return False - - def get_state_error(self) -> AblyException: - return ConnectionErrors[self.state] - - def __get_transport_params(self) -> dict: - protocol_version = Defaults.protocol_version - params = self.ably.auth.get_auth_transport_param() - params["v"] = protocol_version - if self.connection_details: - params["resume"] = self.connection_details.connection_key - return params - - def close_impl(self) -> None: - log.debug('ConnectionManager.close_impl()') - - self.cancel_suspend_timer() - self.start_transition_timer(ConnectionState.CLOSING, fail_state=ConnectionState.CLOSED) - if self.transport: - self.transport.dispose() - if self.connect_base_task: - self.connect_base_task.cancel() - if self.disconnect_transport_task: - self.disconnect_transport_task - self.cancel_retry_timer() - - self.notify_state(ConnectionState.CLOSED) - - def send_protocol_message(self, protocol_message: dict) -> None: - if self.state in ( - ConnectionState.DISCONNECTED, - ConnectionState.CONNECTING, - ): - self.queued_messages.put(protocol_message) - return - - if self.state == ConnectionState.CONNECTED: - if self.transport: - self.transport.send(protocol_message) - else: - log.exception( - "ConnectionManager.send_protocol_message(): can not send message with no active transport" - ) - return - - raise AblyException(f"ConnectionManager.send_protocol_message(): called in {self.state}", 500, 50000) - - def send_queued_messages(self) -> None: - log.info(f'ConnectionManager.send_queued_messages(): sending {self.queued_messages.qsize()} message(s)') - while not self.queued_messages.empty(): - asyncio.create_task(self.send_protocol_message(self.queued_messages.get())) - - def fail_queued_messages(self, err) -> None: - log.info( - f"ConnectionManager.fail_queued_messages(): discarding {self.queued_messages.qsize()} messages;" + - f" reason = {err}" - ) - while not self.queued_messages.empty(): - msg = self.queued_messages.get() - log.exception(f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: {msg}") - - def ping(self) -> float: - if self.__ping_future: - try: - response = self.__ping_future - except asyncio.CancelledError: - raise AblyException("Ping request cancelled due to request timeout", 504, 50003) - return response - - self.__ping_future = asyncio.Future() - if self.__state in [ConnectionState.CONNECTED, ConnectionState.CONNECTING]: - self.__ping_id = get_random_id() - ping_start_time = datetime.now().timestamp() - self.send_protocol_message({"action": ProtocolMessageAction.HEARTBEAT, - "id": self.__ping_id}) - else: - raise AblyException("Cannot send ping request. Calling ping in invalid state", 40000, 400) - try: - asyncio.wait_for(self.__ping_future, self.__timeout_in_secs) - except asyncio.TimeoutError: - raise AblyException("Timeout waiting for ping response", 504, 50003) - - ping_end_time = datetime.now().timestamp() - response_time_ms = (ping_end_time - ping_start_time) * 1000 - return round(response_time_ms, 2) - - def on_connected(self, connection_details: ConnectionDetails, connection_id: str, - reason: Optional[AblyException] = None) -> None: - self.__fail_state = ConnectionState.DISCONNECTED - - self.__connection_details = connection_details - self.connection_id = connection_id - - if connection_details.client_id: - try: - self.ably.auth._configure_client_id(connection_details.client_id) - except IncompatibleClientIdException as e: - self.notify_state(ConnectionState.FAILED, reason=e) - return - - if self.__state == ConnectionState.CONNECTED: - state_change = ConnectionStateChange(ConnectionState.CONNECTED, ConnectionState.CONNECTED, - ConnectionEvent.UPDATE) - self._emit(ConnectionEvent.UPDATE, state_change) - else: - self.notify_state(ConnectionState.CONNECTED, reason=reason) - - self.ably.channels._on_connected() - - def on_disconnected(self, exception: AblyException) -> None: - # RTN15h - if self.transport: - self.transport.dispose() - if exception: - status_code = exception.status_code - if status_code >= 500 and status_code <= 504: # RTN17f1 - if len(self.__fallback_hosts) > 0: - try: - self.connect_with_fallback_hosts(self.__fallback_hosts) - except Exception as e: - self.notify_state(self.__fail_state, reason=e) - return - else: - log.info("No fallback host to try for disconnected protocol message") - elif is_token_error(exception): - self.on_token_error(exception) - else: - self.notify_state(ConnectionState.DISCONNECTED, exception) - else: - log.warn("DISCONNECTED message received without error") - - def on_token_error(self, exception: AblyException) -> None: - if self.__error_reason is None or not is_token_error(self.__error_reason): - self.__error_reason = exception - try: - self.ably.auth._ensure_valid_auth_credentials(force=True) - except Exception as e: - self.on_error_from_authorize(e) - return - self.notify_state(self.__fail_state, exception, retry_immediately=True) - return - self.notify_state(self.__fail_state, exception) - - def on_error(self, msg: dict, exception: AblyException) -> None: - if msg.get("channel") is not None: # RTN15i - self.on_channel_message(msg) - return - if self.transport: - self.transport.dispose() - if is_token_error(exception): # RTN14b - self.on_token_error(exception) - else: - self.enact_state_change(ConnectionState.FAILED, exception) - - def on_error_from_authorize(self, exception: AblyException) -> None: - log.info("ConnectionManager.on_error_from_authorize(): err = %s", exception) - # RSA4a - if exception.code == 40171: - self.notify_state(ConnectionState.FAILED, exception) - elif exception.status_code == 403: - msg = 'Client configured authentication provider returned 403; failing the connection' - log.error(f'ConnectionManager.on_error_from_authorize(): {msg}') - self.notify_state(ConnectionState.FAILED, AblyException(msg, 403, 80019)) - else: - msg = 'Client configured authentication provider request failed' - log.warning(f'ConnectionManager.on_error_from_authorize: {msg}') - self.notify_state(self.__fail_state, AblyException(msg, 401, 80019)) - - def on_closed(self) -> None: - if self.transport: - self.transport.dispose() - if self.connect_base_task: - self.connect_base_task.cancel() - - def on_channel_message(self, msg: dict) -> None: - self.__ably.channels._on_channel_message(msg) - - def on_heartbeat(self, id: Optional[str]) -> None: - if self.__ping_future: - # Resolve on heartbeat from ping request. - if self.__ping_id == id: - if not self.__ping_future.cancelled(): - self.__ping_future.set_result(None) - self.__ping_future = None - - def deactivate_transport(self, reason: Optional[AblyException] = None): - self.transport = None - self.notify_state(ConnectionState.DISCONNECTED, reason) - - def request_state(self, state: ConnectionState, force=False) -> None: - log.debug(f'ConnectionManager.request_state(): state = {state}') - - if not force and state == self.state: - return - - if state == ConnectionState.CONNECTING and self.__state == ConnectionState.CONNECTED: - return - - if state == ConnectionState.CLOSING and self.__state == ConnectionState.CLOSED: - return - - if state == ConnectionState.CONNECTING and self.__state in (ConnectionState.CLOSED, - ConnectionState.FAILED): - self.ably.channels._initialize_channels() - - if not force: - self.enact_state_change(state) - - if state == ConnectionState.CONNECTING: - self.start_connect() - - if state == ConnectionState.CLOSING: - asyncio.create_task(self.close_impl()) - - def start_connect(self) -> None: - self.start_suspend_timer() - self.start_transition_timer(ConnectionState.CONNECTING) - self.connect_base_task = asyncio.create_task(self.connect_base()) - - def connect_with_fallback_hosts(self, fallback_hosts: list) -> Optional[Exception]: - for host in fallback_hosts: - try: - if self.check_connection(): - self.try_host(host) - return - else: - message = "Unable to connect, network unreachable" - log.exception(message) - exception = AblyException(message, status_code=404, code=80003) - self.notify_state(self.__fail_state, exception) - return - except Exception as exc: - exception = exc - log.exception(f'Connection to {host} failed, reason={exception}') - log.exception("No more fallback hosts to try") - return exception - - def connect_base(self) -> None: - fallback_hosts = self.__fallback_hosts - primary_host = self.options.get_realtime_host() - try: - self.try_host(primary_host) - return - except Exception as exception: - log.exception(f'Connection to {primary_host} failed, reason={exception}') - if len(fallback_hosts) > 0: - log.info("Attempting connection to fallback host(s)") - resp = self.connect_with_fallback_hosts(fallback_hosts) - if not resp: - return - exception = resp - self.notify_state(self.__fail_state, reason=exception) - - def try_host(self, host) -> None: - try: - params = self.__get_transport_params() - except AblyException as e: - self.on_error_from_authorize(e) - return - self.transport = WebSocketTransport(self, host, params) - self._emit('transport.pending', self.transport) - self.transport.connect() - - future = asyncio.Future() - - def on_transport_connected(): - log.debug('ConnectionManager.try_a_host(): transport connected') - if self.transport: - self.transport.off('failed', on_transport_failed) - if not future.done(): - future.set_result(None) - - def on_transport_failed(exception): - log.info('ConnectionManager.try_a_host(): transport failed') - if self.transport: - self.transport.off('connected', on_transport_connected) - self.transport.dispose() - future.set_exception(exception) - - self.transport.once('connected', on_transport_connected) - self.transport.once('failed', on_transport_failed) - # Fix asyncio CancelledError in python 3.7 - try: - future - except asyncio.CancelledError: - return - - def notify_state(self, state: ConnectionState, reason: Optional[AblyException] = None, - retry_immediately: Optional[bool] = None) -> None: - # RTN15a - retry_immediately = (retry_immediately is not False) and ( - state == ConnectionState.DISCONNECTED and self.__state == ConnectionState.CONNECTED) - - log.debug( - f'ConnectionManager.notify_state(): new state: {state}' - + ('; will retry immediately' if retry_immediately else '') - ) - - if state == self.__state: - return - - self.cancel_transition_timer() - self.check_suspend_timer(state) - - if retry_immediately: - self.options.loop.call_soon(self.request_state, ConnectionState.CONNECTING) - elif state == ConnectionState.DISCONNECTED: - self.start_retry_timer(self.options.disconnected_retry_timeout) - elif state == ConnectionState.SUSPENDED: - self.start_retry_timer(self.options.suspended_retry_timeout) - - if (state == ConnectionState.DISCONNECTED and not retry_immediately) or state == ConnectionState.SUSPENDED: - self.disconnect_transport() - - self.enact_state_change(state, reason) - - if state == ConnectionState.CONNECTED: - self.send_queued_messages() - elif state in ( - ConnectionState.CLOSING, - ConnectionState.CLOSED, - ConnectionState.SUSPENDED, - ConnectionState.FAILED, - ): - self.fail_queued_messages(reason) - self.ably.channels._propagate_connection_interruption(state, reason) - - def start_transition_timer(self, state: ConnectionState, fail_state: Optional[ConnectionState] = None) -> None: - log.debug(f'ConnectionManager.start_transition_timer(): transition state = {state}') - - if self.transition_timer: - log.debug('ConnectionManager.start_transition_timer(): clearing already-running timer') - self.transition_timer.cancel() - - if fail_state is None: - fail_state = self.__fail_state if state != ConnectionState.CLOSING else ConnectionState.CLOSED - - timeout = self.options.realtime_request_timeout - - def on_transition_timer_expire(): - if self.transition_timer: - self.transition_timer = None - log.info(f'ConnectionManager {state} timer expired, notifying new state: {fail_state}') - self.notify_state( - fail_state, - AblyException("Connection cancelled due to request timeout", 504, 50003) - ) - - log.debug(f'ConnectionManager.start_transition_timer(): setting timer for {timeout}ms') - - self.transition_timer = Timer(timeout, on_transition_timer_expire) - - def cancel_transition_timer(self): - log.debug('ConnectionManager.cancel_transition_timer()') - if self.transition_timer: - self.transition_timer.cancel() - self.transition_timer = None - - def start_suspend_timer(self) -> None: - log.debug('ConnectionManager.start_suspend_timer()') - if self.suspend_timer: - return - - def on_suspend_timer_expire() -> None: - if self.suspend_timer: - self.suspend_timer = None - log.info('ConnectionManager suspend timer expired, requesting new state: suspended') - self.notify_state( - ConnectionState.SUSPENDED, - AblyException("Connection to server unavailable", 400, 80002) - ) - self.__fail_state = ConnectionState.SUSPENDED - self.__connection_details = None - - self.suspend_timer = Timer(Defaults.connection_state_ttl, on_suspend_timer_expire) - - def check_suspend_timer(self, state: ConnectionState) -> None: - if state not in ( - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTED, - ConnectionState.SUSPENDED, - ): - self.cancel_suspend_timer() - - def cancel_suspend_timer(self) -> None: - log.debug('ConnectionManager.cancel_suspend_timer()') - self.__fail_state = ConnectionState.DISCONNECTED - if self.suspend_timer: - self.suspend_timer.cancel() - self.suspend_timer = None - - def start_retry_timer(self, interval: int) -> None: - def on_retry_timeout(): - log.info('ConnectionManager retry timer expired, retrying') - self.retry_timer = None - self.request_state(ConnectionState.CONNECTING) - - self.retry_timer = Timer(interval, on_retry_timeout) - - def cancel_retry_timer(self) -> None: - if self.retry_timer: - self.retry_timer.cancel() - self.retry_timer = None - - def disconnect_transport(self) -> None: - log.info('ConnectionManager.disconnect_transport()') - if self.transport: - self.disconnect_transport_task = asyncio.create_task(self.transport.dispose()) - - def on_auth_updated(self, token_details: TokenDetails): - log.info(f"ConnectionManager.on_auth_updated(): state = {self.state}") - if self.state == ConnectionState.CONNECTED: - auth_message = { - "action": ProtocolMessageAction.AUTH, - "auth": { - "accessToken": token_details.token - } - } - self.send_protocol_message(auth_message) - - state_change = self.once_async() - - if state_change.current == ConnectionState.CONNECTED: - return - elif state_change.current == ConnectionState.FAILED: - raise state_change.reason - elif self.state == ConnectionState.CONNECTING: - if self.connect_base_task and not self.connect_base_task.done(): - self.connect_base_task.cancel() - if self.transport: - self.transport.dispose() - if self.state != ConnectionState.CONNECTED: - future = asyncio.Future() - - def on_state_change(state_change: ConnectionStateChange) -> None: - if state_change.current == ConnectionState.CONNECTED: - self.off('connectionstate', on_state_change) - future.set_result(token_details) - if state_change.current in ( - ConnectionState.CLOSED, - ConnectionState.FAILED, - ConnectionState.SUSPENDED - ): - self.off('connectionstate', on_state_change) - future.set_exception(state_change.reason or self.get_state_error()) - - self.on('connectionstate', on_state_change) - - if self.state == ConnectionState.CONNECTING: - self.start_connect() - else: - self.request_state(ConnectionState.CONNECTING) - - return future - - @property - def ably(self): - return self.__ably - - @property - def state(self) -> ConnectionState: - return self.__state - - @property - def connection_details(self) -> Optional[ConnectionDetails]: - return self.__connection_details diff --git a/ably/sync/realtime/realtime.py b/ably/sync/realtime/realtime.py deleted file mode 100644 index 517d9676..00000000 --- a/ably/sync/realtime/realtime.py +++ /dev/null @@ -1,140 +0,0 @@ -import logging -import asyncio -from typing import Optional -from ably.sync.realtime.realtime_channel import ChannelsSync -from ably.sync.realtime.connection import Connection, ConnectionState -from ably.sync.rest.rest import AblyRestSync - - -log = logging.getLogger(__name__) - - -class AblyRealtime(AblyRestSync): - """ - Ably Realtime Client - - Attributes - ---------- - loop: AbstractEventLoop - asyncio running event loop - auth: Auth - authentication object - options: Options - auth options object - connection: Connection - realtime connection object - channels: Channels - realtime channel object - - Methods - ------- - connect() - Establishes the realtime connection - close() - Closes the realtime connection - """ - - def __init__(self, key: Optional[str] = None, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs): - """Constructs a RealtimeClient object using an Ably API key. - - Parameters - ---------- - key: str - A valid ably API key string - loop: AbstractEventLoop, optional - asyncio running event loop - auto_connect: bool - When true, the client connects to Ably as soon as it is instantiated. - You can set this to false and explicitly connect to Ably using the - connect() method. The default is true. - **kwargs: client options - realtime_host: str - Enables a non-default Ably host to be specified for realtime connections. - For development environments only. The default value is realtime.ably.io. - environment: str - Enables a custom environment to be used with the Ably service. Defaults to `production` - realtime_request_timeout: float - Timeout (in milliseconds) for the wait of acknowledgement for operations performed via a realtime - connection. Operations include establishing a connection with Ably, or sending a HEARTBEAT, - CONNECT, ATTACH, DETACH or CLOSE request. The default is 10 seconds(10000 milliseconds). - disconnected_retry_timeout: float - If the connection is still in the DISCONNECTED state after this delay, the client library will - attempt to reconnect automatically. The default is 15 seconds. - channel_retry_timeout: float - When a channel becomes SUSPENDED following a server initiated DETACHED, after this delay, if the - channel is still SUSPENDED and the connection is in CONNECTED, the client library will attempt to - re-attach the channel automatically. The default is 15 seconds. - fallback_hosts: list[str] - An array of fallback hosts to be used in the case of an error necessitating the use of an - alternative host. If you have been provided a set of custom fallback hosts by Ably, please specify - them here. - connection_state_ttl: float - The duration that Ably will persist the connection state for when a Realtime client is abruptly - disconnected. - suspended_retry_timeout: float - When the connection enters the SUSPENDED state, after this delay, if the state is still SUSPENDED, - the client library attempts to reconnect automatically. The default is 30 seconds. - connectivity_check_url: string - Override the URL used by the realtime client to check if the internet is available. - In the event of a failure to connect to the primary endpoint, the client will send a - GET request to this URL to check if the internet is available. If this request returns - a success response the client will attempt to connect to a fallback host. - Raises - ------ - ValueError - If no authentication key is not provided - """ - - if loop is None: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - log.warning('Realtime client created outside event loop') - - self._is_realtime: bool = True - - # RTC1 - super().__init__(key, loop=loop, **kwargs) - - self.key = key - self.__connection = Connection(self) - self.__channels = ChannelsSync(self) - - # RTN3 - if self.options.auto_connect: - self.connection.connection_manager.request_state(ConnectionState.CONNECTING, force=True) - - # RTC15 - def connect(self) -> None: - """Establishes a realtime connection. - - Explicitly calling connect() is unnecessary unless the autoConnect attribute of the ClientOptions object - is false. Unless already connected or connecting, this method causes the connection to open, entering the - CONNECTING state. - """ - log.info('Realtime.connect() called') - # RTC15a - self.connection.connect() - - # RTC16 - def close(self) -> None: - """Causes the connection to close, entering the closing state. - Once closed, the library will not attempt to re-establish the - connection without an explicit call to connect() - """ - log.info('Realtime.close() called') - # RTC16a - self.connection.close() - super().close() - - # RTC2 - @property - def connection(self) -> Connection: - """Returns the realtime connection object""" - return self.__connection - - # RTC3, RTS1 - @property - def channels(self) -> ChannelsSync: - """Returns the realtime channel object""" - return self.__channels diff --git a/ably/sync/realtime/realtime_channel.py b/ably/sync/realtime/realtime_channel.py deleted file mode 100644 index 805244df..00000000 --- a/ably/sync/realtime/realtime_channel.py +++ /dev/null @@ -1,553 +0,0 @@ -from __future__ import annotations -import asyncio -import logging -from typing import Optional, TYPE_CHECKING -from ably.sync.realtime.connection import ConnectionState -from ably.sync.transport.websockettransport import ProtocolMessageAction -from ably.sync.rest.channel import ChannelSync, ChannelsSync as RestChannels -from ably.sync.types.channelstate import ChannelState, ChannelStateChange -from ably.sync.types.flags import Flag, has_flag -from ably.sync.types.message import Message -from ably.sync.util.eventemitter import EventEmitter -from ably.sync.util.exceptions import AblyException -from ably.sync.util.helper import Timer, is_callable_or_coroutine - -if TYPE_CHECKING: - from ably.sync.realtime.realtime import AblyRealtime - -log = logging.getLogger(__name__) - - -class RealtimeChannel(EventEmitter, ChannelSync): - """ - Ably Realtime Channel - - Attributes - ---------- - name: str - Channel name - state: str - Channel state - error_reason: AblyException - An AblyException instance describing the last error which occurred on the channel, if any. - - Methods - ------- - attach() - Attach to channel - detach() - Detach from channel - subscribe(*args) - Subscribe to messages on a channel - unsubscribe(*args) - Unsubscribe to messages from a channel - """ - - def __init__(self, realtime: AblyRealtime, name: str): - EventEmitter.__init__(self) - self.__name = name - self.__realtime = realtime - self.__state = ChannelState.INITIALIZED - self.__message_emitter = EventEmitter() - self.__state_timer: Optional[Timer] = None - self.__attach_resume = False - self.__channel_serial: Optional[str] = None - self.__retry_timer: Optional[Timer] = None - self.__error_reason: Optional[AblyException] = None - - # Used to listen to state changes internally, if we use the public event emitter interface then internals - # will be disrupted if the user called .off() to remove all listeners - self.__internal_state_emitter = EventEmitter() - - ChannelSync.__init__(self, realtime, name, {}) - - # RTL4 - def attach(self) -> None: - """Attach to channel - - Attach to this channel ensuring the channel is created in the Ably system and all messages published - on the channel are received by any channel listeners registered using subscribe - - Raises - ------ - AblyException - If unable to attach channel - """ - - log.info(f'RealtimeChannel.attach() called, channel = {self.name}') - - # RTL4a - if channel is attached do nothing - if self.state == ChannelState.ATTACHED: - return - - self.__error_reason = None - - # RTL4b - if self.__realtime.connection.state not in [ - ConnectionState.CONNECTING, - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED - ]: - raise AblyException( - message=f"Unable to attach; channel state = {self.state}", - code=90001, - status_code=400 - ) - - if self.state != ChannelState.ATTACHING: - self._request_state(ChannelState.ATTACHING) - - state_change = self.__internal_state_emitter.once_async() - - if state_change.current in (ChannelState.SUSPENDED, ChannelState.FAILED): - raise state_change.reason - - def _attach_impl(self): - log.debug("RealtimeChannel.attach_impl(): sending ATTACH protocol message") - - # RTL4c - attach_msg = { - "action": ProtocolMessageAction.ATTACH, - "channel": self.name, - } - - if self.__attach_resume: - attach_msg["flags"] = Flag.ATTACH_RESUME - if self.__channel_serial: - attach_msg["channelSerial"] = self.__channel_serial - - self._send_message(attach_msg) - - # RTL5 - def detach(self) -> None: - """Detach from channel - - Any resulting channel state change is emitted to any listeners registered - Once all clients globally have detached from the channel, the channel will be released - in the Ably service within two minutes. - - Raises - ------ - AblyException - If unable to detach channel - """ - - log.info(f'RealtimeChannel.detach() called, channel = {self.name}') - - # RTL5g, RTL5b - raise exception if state invalid - if self.__realtime.connection.state in [ConnectionState.CLOSING, ConnectionState.FAILED]: - raise AblyException( - message=f"Unable to detach; channel state = {self.state}", - code=90001, - status_code=400 - ) - - # RTL5a - if channel already detached do nothing - if self.state in [ChannelState.INITIALIZED, ChannelState.DETACHED]: - return - - if self.state == ChannelState.SUSPENDED: - self._notify_state(ChannelState.DETACHED) - return - elif self.state == ChannelState.FAILED: - raise AblyException("Unable to detach; channel state = failed", 90001, 400) - else: - self._request_state(ChannelState.DETACHING) - - # RTL5h - wait for pending connection - if self.__realtime.connection.state == ConnectionState.CONNECTING: - self.__realtime.connect() - - state_change = self.__internal_state_emitter.once_async() - new_state = state_change.current - - if new_state == ChannelState.DETACHED: - return - elif new_state == ChannelState.ATTACHING: - raise AblyException("Detach request superseded by a subsequent attach request", 90000, 409) - else: - raise state_change.reason - - def _detach_impl(self) -> None: - log.debug("RealtimeChannel.detach_impl(): sending DETACH protocol message") - - # RTL5d - detach_msg = { - "action": ProtocolMessageAction.DETACH, - "channel": self.__name, - } - - self._send_message(detach_msg) - - # RTL7 - def subscribe(self, *args) -> None: - """Subscribe to a channel - - Registers a listener for messages on the channel. - The caller supplies a listener function, which is called - each time one or more messages arrives on the channel. - - The function resolves once the channel is attached. - - Parameters - ---------- - *args: event, listener - Subscribe event and listener - - arg1(event): str, optional - Subscribe to messages with the given event name - - arg2(listener): callable - Subscribe to all messages on the channel - - When no event is provided, arg1 is used as the listener. - - Raises - ------ - AblyException - If unable to subscribe to a channel due to invalid connection state - ValueError - If no valid subscribe arguments are passed - """ - if isinstance(args[0], str): - event = args[0] - if not args[1]: - raise ValueError("channel.subscribe called without listener") - if not is_callable_or_coroutine(args[1]): - raise ValueError("subscribe listener must be function or coroutine function") - listener = args[1] - elif is_callable_or_coroutine(args[0]): - listener = args[0] - event = None - else: - raise ValueError('invalid subscribe arguments') - - log.info(f'RealtimeChannel.subscribe called, channel = {self.name}, event = {event}') - - if event is not None: - # RTL7b - self.__message_emitter.on(event, listener) - else: - # RTL7a - self.__message_emitter.on(listener) - - # RTL7c - self.attach() - - # RTL8 - def unsubscribe(self, *args) -> None: - """Unsubscribe from a channel - - Deregister the given listener for (for any/all event names). - This removes an earlier event-specific subscription. - - Parameters - ---------- - *args: event, listener - Unsubscribe event and listener - - arg1(event): str, optional - Unsubscribe to messages with the given event name - - arg2(listener): callable - Unsubscribe to all messages on the channel - - When no event is provided, arg1 is used as the listener. - - Raises - ------ - ValueError - If no valid unsubscribe arguments are passed, no listener or listener is not a function - or coroutine - """ - if len(args) == 0: - event = None - listener = None - elif isinstance(args[0], str): - event = args[0] - if not args[1]: - raise ValueError("channel.unsubscribe called without listener") - if not is_callable_or_coroutine(args[1]): - raise ValueError("unsubscribe listener must be a function or coroutine function") - listener = args[1] - elif is_callable_or_coroutine(args[0]): - listener = args[0] - event = None - else: - raise ValueError('invalid unsubscribe arguments') - - log.info(f'RealtimeChannel.unsubscribe called, channel = {self.name}, event = {event}') - - if listener is None: - # RTL8c - self.__message_emitter.off() - elif event is not None: - # RTL8b - self.__message_emitter.off(event, listener) - else: - # RTL8a - self.__message_emitter.off(listener) - - def _on_message(self, proto_msg: dict) -> None: - action = proto_msg.get('action') - # RTL4c1 - channel_serial = proto_msg.get('channelSerial') - if channel_serial: - self.__channel_serial = channel_serial - # TM2a, TM2c, TM2f - Message.update_inner_message_fields(proto_msg) - - if action == ProtocolMessageAction.ATTACHED: - flags = proto_msg.get('flags') - error = proto_msg.get("error") - exception = None - resumed = False - - if error: - exception = AblyException.from_dict(error) - - if flags: - resumed = has_flag(flags, Flag.RESUMED) - - # RTL12 - if self.state == ChannelState.ATTACHED: - if not resumed: - state_change = ChannelStateChange(self.state, ChannelState.ATTACHED, resumed, exception) - self._emit("update", state_change) - elif self.state == ChannelState.ATTACHING: - self._notify_state(ChannelState.ATTACHED, resumed=resumed) - else: - log.warn("RealtimeChannel._on_message(): ATTACHED received while not attaching") - elif action == ProtocolMessageAction.DETACHED: - if self.state == ChannelState.DETACHING: - self._notify_state(ChannelState.DETACHED) - elif self.state == ChannelState.ATTACHING: - self._notify_state(ChannelState.SUSPENDED) - else: - self._request_state(ChannelState.ATTACHING) - elif action == ProtocolMessageAction.MESSAGE: - messages = Message.from_encoded_array(proto_msg.get('messages')) - for message in messages: - self.__message_emitter._emit(message.name, message) - elif action == ProtocolMessageAction.ERROR: - error = AblyException.from_dict(proto_msg.get('error')) - self._notify_state(ChannelState.FAILED, reason=error) - - def _request_state(self, state: ChannelState) -> None: - log.debug(f'RealtimeChannel._request_state(): state = {state}') - self._notify_state(state) - self._check_pending_state() - - def _notify_state(self, state: ChannelState, reason: Optional[AblyException] = None, - resumed: bool = False) -> None: - log.debug(f'RealtimeChannel._notify_state(): state = {state}') - - self.__clear_state_timer() - - if state == self.state: - return - - if reason is not None: - self.__error_reason = reason - - if state == ChannelState.INITIALIZED: - self.__error_reason = None - - if state == ChannelState.SUSPENDED and self.ably.connection.state == ConnectionState.CONNECTED: - self.__start_retry_timer() - else: - self.__cancel_retry_timer() - - # RTL4j1 - if state == ChannelState.ATTACHED: - self.__attach_resume = True - if state in (ChannelState.DETACHING, ChannelState.FAILED): - self.__attach_resume = False - - # RTP5a1 - if state in (ChannelState.DETACHED, ChannelState.SUSPENDED, ChannelState.FAILED): - self.__channel_serial = None - - state_change = ChannelStateChange(self.__state, state, resumed, reason=reason) - - self.__state = state - self._emit(state, state_change) - self.__internal_state_emitter._emit(state, state_change) - - def _send_message(self, msg: dict) -> None: - asyncio.create_task(self.__realtime.connection.connection_manager.send_protocol_message(msg)) - - def _check_pending_state(self): - connection_state = self.__realtime.connection.connection_manager.state - - if connection_state is not ConnectionState.CONNECTED: - log.debug(f"RealtimeChannel._check_pending_state(): connection state = {connection_state}") - return - - if self.state == ChannelState.ATTACHING: - self.__start_state_timer() - self._attach_impl() - elif self.state == ChannelState.DETACHING: - self.__start_state_timer() - self._detach_impl() - - def __start_state_timer(self) -> None: - if not self.__state_timer: - def on_timeout() -> None: - log.debug('RealtimeChannel.start_state_timer(): timer expired') - self.__state_timer = None - self.__timeout_pending_state() - - self.__state_timer = Timer(self.__realtime.options.realtime_request_timeout, on_timeout) - - def __clear_state_timer(self) -> None: - if self.__state_timer: - self.__state_timer.cancel() - self.__state_timer = None - - def __timeout_pending_state(self) -> None: - if self.state == ChannelState.ATTACHING: - self._notify_state( - ChannelState.SUSPENDED, reason=AblyException("Channel attach timed out", 408, 90007)) - elif self.state == ChannelState.DETACHING: - self._notify_state(ChannelState.ATTACHED, reason=AblyException("Channel detach timed out", 408, 90007)) - else: - self._check_pending_state() - - def __start_retry_timer(self) -> None: - if self.__retry_timer: - return - - self.__retry_timer = Timer(self.ably.options.channel_retry_timeout, self.__on_retry_timer_expire) - - def __cancel_retry_timer(self) -> None: - if self.__retry_timer: - self.__retry_timer.cancel() - self.__retry_timer = None - - def __on_retry_timer_expire(self) -> None: - if self.state == ChannelState.SUSPENDED and self.ably.connection.state == ConnectionState.CONNECTED: - self.__retry_timer = None - log.info("RealtimeChannel retry timer expired, attempting a new attach") - self._request_state(ChannelState.ATTACHING) - - # RTL23 - @property - def name(self) -> str: - """Returns channel name""" - return self.__name - - # RTL2b - @property - def state(self) -> ChannelState: - """Returns channel state""" - return self.__state - - @state.setter - def state(self, state: ChannelState) -> None: - self.__state = state - - # RTL24 - @property - def error_reason(self) -> Optional[AblyException]: - """An AblyException instance describing the last error which occurred on the channel, if any.""" - return self.__error_reason - - -class ChannelsSync(RestChannels): - """Creates and destroys RealtimeChannel objects. - - Methods - ------- - get(name) - Gets a channel - release(name) - Releases a channel - """ - - # RTS3 - def get(self, name: str) -> RealtimeChannel: - """Creates a new RealtimeChannel object, or returns the existing channel object. - - Parameters - ---------- - - name: str - Channel name - """ - if name not in self.__all: - channel = self.__all[name] = RealtimeChannel(self.__ably, name) - else: - channel = self.__all[name] - return channel - - # RTS4 - def release(self, name: str) -> None: - """Releases a RealtimeChannel object, deleting it, and enabling it to be garbage collected - - It also removes any listeners associated with the channel. - To release a channel, the channel state must be INITIALIZED, DETACHED, or FAILED. - - - Parameters - ---------- - name: str - Channel name - """ - if name not in self.__all: - return - del self.__all[name] - - def _on_channel_message(self, msg: dict) -> None: - channel_name = msg.get('channel') - if not channel_name: - log.error( - 'Channels.on_channel_message()', - f'received event without channel, action = {msg.get("action")}' - ) - return - - channel = self.__all[channel_name] - if not channel: - log.warning( - 'Channels.on_channel_message()', - f'receieved event for non-existent channel: {channel_name}' - ) - return - - channel._on_message(msg) - - def _propagate_connection_interruption(self, state: ConnectionState, reason: Optional[AblyException]) -> None: - from_channel_states = ( - ChannelState.ATTACHING, - ChannelState.ATTACHED, - ChannelState.DETACHING, - ChannelState.SUSPENDED, - ) - - connection_to_channel_state = { - ConnectionState.CLOSING: ChannelState.DETACHED, - ConnectionState.CLOSED: ChannelState.DETACHED, - ConnectionState.FAILED: ChannelState.FAILED, - ConnectionState.SUSPENDED: ChannelState.SUSPENDED, - } - - for channel_name in self.__all: - channel = self.__all[channel_name] - if channel.state in from_channel_states: - channel._notify_state(connection_to_channel_state[state], reason) - - def _on_connected(self) -> None: - for channel_name in self.__all: - channel = self.__all[channel_name] - if channel.state == ChannelState.ATTACHING or channel.state == ChannelState.DETACHING: - channel._check_pending_state() - elif channel.state == ChannelState.SUSPENDED: - asyncio.create_task(channel.attach()) - elif channel.state == ChannelState.ATTACHED: - channel._request_state(ChannelState.ATTACHING) - - def _initialize_channels(self) -> None: - for channel_name in self.__all: - channel = self.__all[channel_name] - channel._request_state(ChannelState.INITIALIZED) diff --git a/ably/sync/rest/__init__.py b/ably/sync/rest/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ably/sync/rest/auth.py b/ably/sync/rest/auth.py deleted file mode 100644 index 851a2ace..00000000 --- a/ably/sync/rest/auth.py +++ /dev/null @@ -1,425 +0,0 @@ -from __future__ import annotations -import base64 -from datetime import timedelta -import logging -import time -from typing import Optional, TYPE_CHECKING, Union -import uuid -import httpx - -from ably.sync.types.options import Options -if TYPE_CHECKING: - from ably.sync.rest.rest import AblyRestSync - from ably.sync.realtime.realtime import AblyRealtime - -from ably.sync.types.capability import Capability -from ably.sync.types.tokendetails import TokenDetails -from ably.sync.types.tokenrequest import TokenRequest -from ably.sync.util.exceptions import AblyAuthException, AblyException, IncompatibleClientIdException - -__all__ = ["AuthSync"] - -log = logging.getLogger(__name__) - - -class AuthSync: - - class Method: - BASIC = "BASIC" - TOKEN = "TOKEN" - - def __init__(self, ably: Union[AblyRestSync, AblyRealtime], options: Options): - self.__ably = ably - self.__auth_options = options - - if not self.ably._is_realtime: - self.__client_id = options.client_id - if not self.__client_id and options.token_details: - self.__client_id = options.token_details.client_id - else: - self.__client_id = None - self.__client_id_validated: bool = False - - self.__basic_credentials: Optional[str] = None - self.__auth_params: Optional[dict] = None - self.__token_details: Optional[TokenDetails] = None - self.__time_offset: Optional[int] = None - - must_use_token_auth = options.use_token_auth is True - must_not_use_token_auth = options.use_token_auth is False - can_use_basic_auth = options.key_secret is not None - if not must_use_token_auth and can_use_basic_auth: - # We have the key, no need to authenticate the client - # default to using basic auth - log.debug("anonymous, using basic auth") - self.__auth_mechanism = AuthSync.Method.BASIC - basic_key = "%s:%s" % (options.key_name, options.key_secret) - basic_key = base64.b64encode(basic_key.encode('utf-8')) - self.__basic_credentials = basic_key.decode('ascii') - return - elif must_not_use_token_auth and not can_use_basic_auth: - raise ValueError('If use_token_auth is False you must provide a key') - - # Using token auth - self.__auth_mechanism = AuthSync.Method.TOKEN - - if options.token_details: - self.__token_details = options.token_details - elif options.auth_token: - self.__token_details = TokenDetails(token=options.auth_token) - else: - self.__token_details = None - - if options.auth_callback: - log.debug("using token auth with auth_callback") - elif options.auth_url: - log.debug("using token auth with auth_url") - elif options.key_secret: - log.debug("using token auth with client-side signing") - elif options.auth_token: - log.debug("using token auth with supplied token only") - elif options.token_details: - log.debug("using token auth with supplied token_details") - else: - raise ValueError("Can't authenticate via token, must provide " - "auth_callback, auth_url, key, token or a TokenDetail") - - def get_auth_transport_param(self): - auth_credentials = {} - if self.auth_options.client_id: - auth_credentials["client_id"] = self.auth_options.client_id - if self.__auth_mechanism == AuthSync.Method.BASIC: - key_name = self.__auth_options.key_name - key_secret = self.__auth_options.key_secret - auth_credentials["key"] = f"{key_name}:{key_secret}" - elif self.__auth_mechanism == AuthSync.Method.TOKEN: - token_details = self._ensure_valid_auth_credentials() - auth_credentials["accessToken"] = token_details.token - return auth_credentials - - def __authorize_when_necessary(self, token_params=None, auth_options=None, force=False): - token_details = self._ensure_valid_auth_credentials(token_params, auth_options, force) - - if self.ably._is_realtime: - self.ably.connection.connection_manager.on_auth_updated(token_details) - - return token_details - - def _ensure_valid_auth_credentials(self, token_params=None, auth_options=None, force=False): - self.__auth_mechanism = AuthSync.Method.TOKEN - if token_params is None: - token_params = dict(self.auth_options.default_token_params) - else: - self.auth_options.default_token_params = dict(token_params) - self.auth_options.default_token_params.pop('timestamp', None) - - if auth_options is not None: - self.auth_options.replace(auth_options) - auth_options = dict(self.auth_options.auth_options) - if self.client_id is not None: - token_params['client_id'] = self.client_id - - token_details = self.__token_details - if not force and not self.token_details_has_expired(): - log.debug("using cached token; expires = %d", - token_details.expires) - return token_details - - self.__token_details = self.request_token(token_params, **auth_options) - self._configure_client_id(self.__token_details.client_id) - - return self.__token_details - - def token_details_has_expired(self): - token_details = self.__token_details - if token_details is None: - return True - - if not self.__time_offset: - return False - - expires = token_details.expires - if expires is None: - return False - - timestamp = self._timestamp() - if self.__time_offset: - timestamp += self.__time_offset - - return expires < timestamp + token_details.TOKEN_EXPIRY_BUFFER - - def authorize(self, token_params: Optional[dict] = None, auth_options=None): - return self.__authorize_when_necessary(token_params, auth_options, force=True) - - def request_token(self, token_params: Optional[dict] = None, - # auth_options - key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None, - auth_url: Optional[str] = None, auth_method: Optional[str] = None, - auth_headers: Optional[dict] = None, auth_params: Optional[dict] = None, - query_time=None): - token_params = token_params or {} - token_params = dict(self.auth_options.default_token_params, - **token_params) - key_name = key_name or self.auth_options.key_name - key_secret = key_secret or self.auth_options.key_secret - - log.debug("Auth callback: %s" % auth_callback) - log.debug("Auth options: %s" % self.auth_options) - if query_time is None: - query_time = self.auth_options.query_time - query_time = bool(query_time) - auth_callback = auth_callback or self.auth_options.auth_callback - auth_url = auth_url or self.auth_options.auth_url - - auth_params = auth_params or self.auth_options.auth_params or {} - - auth_method = (auth_method or self.auth_options.auth_method).upper() - - auth_headers = auth_headers or self.auth_options.auth_headers or {} - - log.debug("Token Params: %s" % token_params) - if auth_callback: - log.debug("using token auth with authCallback") - try: - token_request = auth_callback(token_params) - except Exception as e: - raise AblyException("auth_callback raised an exception", 401, 40170, cause=e) - elif auth_url: - log.debug("using token auth with authUrl") - - token_request = self.token_request_from_auth_url( - auth_method, auth_url, token_params, auth_headers, auth_params) - elif key_name is not None and key_secret is not None: - token_request = self.create_token_request( - token_params, key_name=key_name, key_secret=key_secret, - query_time=query_time) - else: - msg = "Need a new token but auth_options does not include a way to request one" - log.exception(msg) - raise AblyAuthException(msg, 403, 40171) - if isinstance(token_request, TokenDetails): - return token_request - elif isinstance(token_request, dict) and 'issued' in token_request: - return TokenDetails.from_dict(token_request) - elif isinstance(token_request, dict): - try: - token_request = TokenRequest.from_json(token_request) - except TypeError as e: - msg = "Expected token request callback to call back with a token string, token request object, or \ - token details object" - raise AblyAuthException(msg, 401, 40170, cause=e) - elif isinstance(token_request, str): - if len(token_request) == 0: - raise AblyAuthException("Token string is empty", 401, 4017) - return TokenDetails(token=token_request) - elif token_request is None: - raise AblyAuthException("Token string was None", 401, 40170) - - token_path = "/keys/%s/requestToken" % token_request.key_name - - response = self.ably.http.post( - token_path, - headers=auth_headers, - body=token_request.to_dict(), - skip_auth=True - ) - - AblyException.raise_for_response(response) - response_dict = response.to_native() - log.debug("Token: %s" % str(response_dict.get("token"))) - return TokenDetails.from_dict(response_dict) - - def create_token_request(self, token_params: Optional[dict] = None, key_name: Optional[str] = None, - key_secret: Optional[str] = None, query_time=None): - token_params = token_params or {} - token_request = {} - - key_name = key_name or self.auth_options.key_name - key_secret = key_secret or self.auth_options.key_secret - if not key_name or not key_secret: - log.debug('key_name or key_secret blank') - raise AblyException("No key specified: no means to generate a token", 401, 40101) - - token_request['key_name'] = key_name - if token_params.get('timestamp'): - token_request['timestamp'] = token_params['timestamp'] - else: - if query_time is None: - query_time = self.auth_options.query_time - - if query_time: - if self.__time_offset is None: - server_time = self.ably.time() - local_time = self._timestamp() - self.__time_offset = server_time - local_time - token_request['timestamp'] = server_time - else: - local_time = self._timestamp() - token_request['timestamp'] = local_time + self.__time_offset - else: - token_request['timestamp'] = self._timestamp() - - token_request['timestamp'] = int(token_request['timestamp']) - - ttl = token_params.get('ttl') - if ttl is not None: - if isinstance(ttl, timedelta): - ttl = ttl.total_seconds() * 1000 - token_request['ttl'] = int(ttl) - - capability = token_params.get('capability') - if capability is not None: - token_request['capability'] = str(Capability(capability)) - - token_request["client_id"] = ( - token_params.get('client_id') or self.client_id) - - # Note: There is no expectation that the client - # specifies the nonce; this is done by the library - # However, this can be overridden by the client - # simply for testing purposes - token_request["nonce"] = token_params.get('nonce') or self._random_nonce() - - token_req = TokenRequest(**token_request) - - if token_params.get('mac') is None: - # Note: There is no expectation that the client - # specifies the mac; this is done by the library - # However, this can be overridden by the client - # simply for testing purposes. - token_req.sign_request(key_secret.encode('utf8')) - else: - token_req.mac = token_params['mac'] - - return token_req - - @property - def ably(self): - return self.__ably - - @property - def auth_mechanism(self): - return self.__auth_mechanism - - @property - def auth_options(self): - return self.__auth_options - - @property - def auth_params(self): - return self.__auth_params - - @property - def basic_credentials(self): - return self.__basic_credentials - - @property - def token_credentials(self): - if self.__token_details: - token = self.__token_details.token - token_key = base64.b64encode(token.encode('utf-8')) - return token_key.decode('ascii') - - @property - def token_details(self): - return self.__token_details - - @property - def client_id(self): - return self.__client_id - - @property - def time_offset(self): - return self.__time_offset - - def _configure_client_id(self, new_client_id): - log.debug("Auth._configure_client_id(): new client_id = %s", new_client_id) - original_client_id = self.client_id or self.auth_options.client_id - - # If new client ID from Ably is a wildcard, but preconfigured clientId is set, - # then keep the existing clientId - if original_client_id != '*' and new_client_id == '*': - self.__client_id_validated = True - self.__client_id = original_client_id - return - - # If client_id is defined and not a wildcard, prevent it changing, this is not supported - if original_client_id is not None and original_client_id != '*' and new_client_id != original_client_id: - raise IncompatibleClientIdException( - "Client ID is immutable once configured for a client. " - "Client ID cannot be changed to '{}'".format(new_client_id), 400, 40102) - - self.__client_id_validated = True - self.__client_id = new_client_id - - def can_assume_client_id(self, assumed_client_id): - original_client_id = self.client_id or self.auth_options.client_id - - if self.__client_id_validated: - return self.client_id == '*' or self.client_id == assumed_client_id - elif original_client_id is None or original_client_id == '*': - return True # client ID is unknown - else: - return original_client_id == assumed_client_id - - def _get_auth_headers(self): - if self.__auth_mechanism == AuthSync.Method.BASIC: - # RSA7e2 - if self.client_id: - return { - 'Authorization': 'Basic %s' % self.basic_credentials, - 'X-Ably-ClientId': base64.b64encode(self.client_id.encode('utf-8')) - } - return { - 'Authorization': 'Basic %s' % self.basic_credentials, - } - else: - self.__authorize_when_necessary() - return { - 'Authorization': 'Bearer %s' % self.token_credentials, - } - - def _timestamp(self): - """Returns the local time in milliseconds since the unix epoch""" - return int(time.time() * 1000) - - def _random_nonce(self): - return uuid.uuid4().hex[:16] - - def token_request_from_auth_url(self, method: str, url: str, token_params, - headers, auth_params): - body = None - params = None - if method == 'GET': - body = {} - params = dict(auth_params, **token_params) - elif method == 'POST': - if isinstance(auth_params, TokenDetails): - auth_params = auth_params.to_dict() - params = {} - body = dict(auth_params, **token_params) - - from ably.sync.http.http import Response - with httpx.Client(http2=True) as client: - resp = client.request(method=method, url=url, headers=headers, params=params, data=body) - response = Response(resp) - - AblyException.raise_for_response(response) - - content_type = response.response.headers.get('content-type') - - if not content_type: - raise AblyAuthException("auth_url response missing a content-type header", 401, 40170) - - is_json = "application/json" in content_type - is_text = "application/jwt" in content_type or "text/plain" in content_type - - if is_json: - token_request = response.to_native() - elif is_text: - token_request = response.text - else: - msg = 'auth_url responded with unacceptable content-type ' + content_type + \ - ', should be either text/plain, application/jwt or application/json', - raise AblyAuthException(msg, 401, 40170) - return token_request diff --git a/ably/sync/rest/channel.py b/ably/sync/rest/channel.py deleted file mode 100644 index 8804d46e..00000000 --- a/ably/sync/rest/channel.py +++ /dev/null @@ -1,229 +0,0 @@ -import base64 -from collections import OrderedDict -import logging -import json -import os -from typing import Iterator -from urllib import parse - -from methoddispatch import SingleDispatch, singledispatch -import msgpack - -from ably.sync.http.paginatedresult import PaginatedResultSync, format_params -from ably.sync.types.channeldetails import ChannelDetails -from ably.sync.types.message import Message, make_message_response_handler -from ably.sync.types.presence import Presence -from ably.sync.util.crypto import get_cipher -from ably.sync.util.exceptions import catch_all, IncompatibleClientIdException - -log = logging.getLogger(__name__) - - -class ChannelSync(SingleDispatch): - def __init__(self, ably, name, options): - self.__ably = ably - self.__name = name - self.__base_path = '/channels/%s/' % parse.quote_plus(name, safe=':') - self.__cipher = None - self.options = options - self.__presence = Presence(self) - - @catch_all - def history(self, direction=None, limit: int = None, start=None, end=None): - """Returns the history for this channel""" - params = format_params({}, direction=direction, start=start, end=end, limit=limit) - path = self.__base_path + 'messages' + params - - message_handler = make_message_response_handler(self.__cipher) - return PaginatedResultSync.paginated_query( - self.ably.http, url=path, response_processor=message_handler) - - def __publish_request_body(self, messages): - """ - Helper private method, separated from publish() to test RSL1j - """ - # Idempotent publishing - if self.ably.options.idempotent_rest_publishing: - # RSL1k1 - if all(message.id is None for message in messages): - base_id = base64.b64encode(os.urandom(12)).decode() - for serial, message in enumerate(messages): - message.id = '{}:{}'.format(base_id, serial) - - request_body_list = [] - for m in messages: - if m.client_id == '*': - raise IncompatibleClientIdException( - 'Wildcard client_id is reserved and cannot be used when publishing messages', - 400, 40012) - elif m.client_id is not None and not self.ably.auth.can_assume_client_id(m.client_id): - raise IncompatibleClientIdException( - 'Cannot publish with client_id \'{}\' as it is incompatible with the ' - 'current configured client_id \'{}\''.format(m.client_id, self.ably.auth.client_id), - 400, 40012) - - if self.cipher: - m.encrypt(self.__cipher) - - request_body_list.append(m) - - request_body = [ - message.as_dict(binary=self.ably.options.use_binary_protocol) - for message in request_body_list] - - if len(request_body) == 1: - request_body = request_body[0] - - return request_body - - @singledispatch - def _publish(self, arg, *args, **kwargs): - raise TypeError('Unexpected type %s' % type(arg)) - - @_publish.register(Message) - def publish_message(self, message, params=None, timeout=None): - return self.publish_messages([message], params, timeout=timeout) - - @_publish.register(list) - def publish_messages(self, messages, params=None, timeout=None): - request_body = self.__publish_request_body(messages) - if not self.ably.options.use_binary_protocol: - request_body = json.dumps(request_body, separators=(',', ':')) - else: - request_body = msgpack.packb(request_body, use_bin_type=True) - - path = self.__base_path + 'messages' - if params: - params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} - path += '?' + parse.urlencode(params) - return self.ably.http.post(path, body=request_body, timeout=timeout) - - @_publish.register(str) - def publish_name_data(self, name, data, timeout=None): - messages = [Message(name, data)] - return self.publish_messages(messages, timeout=timeout) - - def publish(self, *args, **kwargs): - """Publishes a message on this channel. - - :Parameters: - - `name`: the name for this message. - - `data`: the data for this message. - - `messages`: list of `Message` objects to be published. - - `message`: a single `Message` objet to be published - - :attention: You can publish using `name` and `data` OR `messages` OR - `message`, never all three. - """ - # For backwards compatibility - if len(args) == 0: - if len(kwargs) == 0: - return self.publish_name_data(None, None) - - if 'name' in kwargs or 'data' in kwargs: - name = kwargs.pop('name', None) - data = kwargs.pop('data', None) - return self.publish_name_data(name, data, **kwargs) - - if 'messages' in kwargs: - messages = kwargs.pop('messages') - return self.publish_messages(messages, **kwargs) - - return self._publish(*args, **kwargs) - - def status(self): - """Retrieves current channel active status with no. of publishers, subscribers, presence_members etc""" - - path = '/channels/%s' % self.name - response = self.ably.http.get(path) - obj = response.to_native() - return ChannelDetails.from_dict(obj) - - @property - def ably(self): - return self.__ably - - @property - def name(self): - return self.__name - - @property - def base_path(self): - return self.__base_path - - @property - def cipher(self): - return self.__cipher - - @property - def options(self): - return self.__options - - @property - def presence(self): - return self.__presence - - @options.setter - def options(self, options): - self.__options = options - - if options and 'cipher' in options: - cipher = options.get('cipher') - if cipher is not None: - cipher = get_cipher(cipher) - self.__cipher = cipher - - -class ChannelsSync: - def __init__(self, rest): - self.__ably = rest - self.__all: dict = OrderedDict() - - def get(self, name, **kwargs): - if isinstance(name, bytes): - name = name.decode('ascii') - - if name not in self.__all: - result = self.__all[name] = ChannelSync(self.__ably, name, kwargs) - else: - result = self.__all[name] - if len(kwargs) != 0: - result.options = kwargs - - return result - - def __getitem__(self, key): - return self.get(key) - - def __getattr__(self, name): - return self.get(name) - - def __contains__(self, item): - if isinstance(item, ChannelSync): - name = item.name - elif isinstance(item, bytes): - name = item.decode('ascii') - else: - name = item - - return name in self.__all - - def __iter__(self) -> Iterator[str]: - return iter(self.__all.values()) - - # RSN4 - def release(self, name: str): - """Releases a Channel object, deleting it, and enabling it to be garbage collected. - If the channel does not exist, nothing happens. - - It also removes any listeners associated with the channel. - - Parameters - ---------- - name: str - Channel name - """ - - if name not in self.__all: - return - del self.__all[name] diff --git a/ably/sync/rest/push.py b/ably/sync/rest/push.py deleted file mode 100644 index 3bb4de40..00000000 --- a/ably/sync/rest/push.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import Optional -from ably.sync.http.paginatedresult import PaginatedResultSync, format_params -from ably.sync.types.device import DeviceDetails, device_details_response_processor -from ably.sync.types.channelsubscription import PushChannelSubscription, channel_subscriptions_response_processor -from ably.sync.types.channelsubscription import channels_response_processor - - -class PushSync: - - def __init__(self, ably): - self.__ably = ably - self.__admin = PushAdminSync(ably) - - @property - def admin(self): - return self.__admin - - -class PushAdminSync: - - def __init__(self, ably): - self.__ably = ably - self.__device_registrations = PushDeviceRegistrations(ably) - self.__channel_subscriptions = PushChannelSubscriptions(ably) - - @property - def ably(self): - return self.__ably - - @property - def device_registrations(self): - return self.__device_registrations - - @property - def channel_subscriptions(self): - return self.__channel_subscriptions - - def publish(self, recipient: dict, data: dict, timeout: Optional[float] = None): - """Publish a push notification to a single device. - - :Parameters: - - `recipient`: the recipient of the notification - - `data`: the data of the notification - """ - if not isinstance(recipient, dict): - raise TypeError('Unexpected %s recipient, expected a dict' % type(recipient)) - - if not isinstance(data, dict): - raise TypeError('Unexpected %s data, expected a dict' % type(data)) - - if not recipient: - raise ValueError('recipient is empty') - - if not data: - raise ValueError('data is empty') - - body = data.copy() - body.update({'recipient': recipient}) - self.ably.http.post('/push/publish', body=body, timeout=timeout) - - -class PushDeviceRegistrations: - - def __init__(self, ably): - self.__ably = ably - - @property - def ably(self): - return self.__ably - - def get(self, device_id: str): - """Returns a DeviceDetails object if the device id is found or results - in a not found error if the device cannot be found. - - :Parameters: - - `device_id`: the id of the device - """ - path = '/push/deviceRegistrations/%s' % device_id - response = self.ably.http.get(path) - obj = response.to_native() - return DeviceDetails.from_dict(obj) - - def list(self, **params): - """Returns a PaginatedResult object with the list of DeviceDetails - objects, filtered by the given parameters. - - :Parameters: - - `**params`: the parameters used to filter the list - """ - path = '/push/deviceRegistrations' + format_params(params) - return PaginatedResultSync.paginated_query( - self.ably.http, url=path, - response_processor=device_details_response_processor) - - def save(self, device: dict): - """Creates or updates the device. Returns a DeviceDetails object. - - :Parameters: - - `device`: a dictionary with the device information - """ - device_details = DeviceDetails.factory(device) - path = '/push/deviceRegistrations/%s' % device_details.id - body = device_details.as_dict() - response = self.ably.http.put(path, body=body) - obj = response.to_native() - return DeviceDetails.from_dict(obj) - - def remove(self, device_id: str): - """Deletes the registered device identified by the given device id. - - :Parameters: - - `device_id`: the id of the device - """ - path = '/push/deviceRegistrations/%s' % device_id - return self.ably.http.delete(path) - - def remove_where(self, **params): - """Deletes the registered devices identified by the given parameters. - - :Parameters: - - `**params`: the parameters that identify the devices to remove - """ - path = '/push/deviceRegistrations' + format_params(params) - return self.ably.http.delete(path) - - -class PushChannelSubscriptions: - - def __init__(self, ably): - self.__ably = ably - - @property - def ably(self): - return self.__ably - - def list(self, **params): - """Returns a PaginatedResult object with the list of - PushChannelSubscription objects, filtered by the given parameters. - - :Parameters: - - `**params`: the parameters used to filter the list - """ - path = '/push/channelSubscriptions' + format_params(params) - return PaginatedResultSync.paginated_query(self.ably.http, url=path, - response_processor=channel_subscriptions_response_processor) - - def list_channels(self, **params): - """Returns a PaginatedResult object with the list of - PushChannelSubscription objects, filtered by the given parameters. - - :Parameters: - - `**params`: the parameters used to filter the list - """ - path = '/push/channels' + format_params(params) - return PaginatedResultSync.paginated_query(self.ably.http, url=path, - response_processor=channels_response_processor) - - def save(self, subscription: dict): - """Creates or updates the subscription. Returns a - PushChannelSubscription object. - - :Parameters: - - `subscription`: a dictionary with the subscription information - """ - subscription = PushChannelSubscription.factory(subscription) - path = '/push/channelSubscriptions' - body = subscription.as_dict() - response = self.ably.http.post(path, body=body) - obj = response.to_native() - return PushChannelSubscription.from_dict(obj) - - def remove(self, subscription: dict): - """Deletes the given subscription. - - :Parameters: - - `subscription`: the subscription object to remove - """ - subscription = PushChannelSubscription.factory(subscription) - params = subscription.as_dict() - return self.remove_where(**params) - - def remove_where(self, **params): - """Deletes the subscriptions identified by the given parameters. - - :Parameters: - - `**params`: the parameters that identify the subscriptions to remove - """ - path = '/push/channelSubscriptions' + format_params(**params) - return self.ably.http.delete(path) diff --git a/ably/sync/rest/rest.py b/ably/sync/rest/rest.py deleted file mode 100644 index 5f0392e1..00000000 --- a/ably/sync/rest/rest.py +++ /dev/null @@ -1,148 +0,0 @@ -import logging -from typing import Optional -from urllib.parse import urlencode - -from ably.sync.http.http import HttpSync -from ably.sync.http.paginatedresult import PaginatedResultSync, HttpPaginatedResponseSync -from ably.sync.http.paginatedresult import format_params -from ably.sync.rest.auth import AuthSync -from ably.sync.rest.channel import ChannelsSync -from ably.sync.rest.push import PushSync -from ably.sync.util.exceptions import AblyException, catch_all -from ably.sync.types.options import Options -from ably.sync.types.stats import stats_response_processor -from ably.sync.types.tokendetails import TokenDetails - -log = logging.getLogger(__name__) - - -class AblyRestSync: - """Ably Rest Client""" - - def __init__(self, key: Optional[str] = None, token: Optional[str] = None, - token_details: Optional[TokenDetails] = None, **kwargs): - """Create an AblyRest instance. - - :Parameters: - **Credentials** - - `key`: a valid key string - - **Or** - - `token`: a valid token string - - `token_details`: an instance of TokenDetails class - - **Optional Parameters** - - `client_id`: Undocumented - - `rest_host`: The host to connect to. Defaults to rest.ably.io - - `environment`: The environment to use. Defaults to 'production' - - `port`: The port to connect to. Defaults to 80 - - `tls_port`: The tls_port to connect to. Defaults to 443 - - `tls`: Specifies whether the client should use TLS. Defaults - to True - - `auth_token`: Undocumented - - `auth_callback`: Undocumented - - `auth_url`: Undocumented - - `keep_alive`: use persistent connections. Defaults to True - """ - if key is not None and ('key_name' in kwargs or 'key_secret' in kwargs): - raise ValueError("key and key_name or key_secret are mutually exclusive. " - "Provider either a key or key_name & key_secret") - if key is not None: - options = Options(key=key, **kwargs) - elif token is not None: - options = Options(auth_token=token, **kwargs) - elif token_details is not None: - if not isinstance(token_details, TokenDetails): - raise ValueError("token_details must be an instance of TokenDetails") - options = Options(token_details=token_details, **kwargs) - elif not ('auth_callback' in kwargs or 'auth_url' in kwargs or - # and don't have both key_name and key_secret - ('key_name' in kwargs and 'key_secret' in kwargs)): - raise ValueError("key is missing. Either an API key, token, or token auth method must be provided") - else: - options = Options(**kwargs) - - try: - self._is_realtime - except AttributeError: - self._is_realtime = False - - self.__http = HttpSync(self, options) - self.__auth = AuthSync(self, options) - self.__http.auth = self.__auth - - self.__channels = ChannelsSync(self) - self.__options = options - self.__push = PushSync(self) - - def __enter__(self): - return self - - @catch_all - def stats(self, direction: Optional[str] = None, start=None, end=None, params: Optional[dict] = None, - limit: Optional[int] = None, paginated=None, unit=None, timeout=None): - """Returns the stats for this application""" - formatted_params = format_params(params, direction=direction, start=start, end=end, limit=limit, unit=unit) - url = '/stats' + formatted_params - return PaginatedResultSync.paginated_query( - self.http, url=url, response_processor=stats_response_processor) - - @catch_all - def time(self, timeout: Optional[float] = None) -> float: - """Returns the current server time in ms since the unix epoch""" - r = self.http.get('/time', skip_auth=True, timeout=timeout) - AblyException.raise_for_response(r) - return r.to_native()[0] - - @property - def client_id(self) -> Optional[str]: - return self.options.client_id - - @property - def channels(self): - """Returns the channels container object""" - return self.__channels - - @property - def auth(self): - return self.__auth - - @property - def http(self): - return self.__http - - @property - def options(self): - return self.__options - - @property - def push(self): - return self.__push - - def request(self, method: str, path: str, version: str, params: - Optional[dict] = None, body=None, headers=None): - if version is None: - raise AblyException("No version parameter", 400, 40000) - - url = path - if params: - url += '?' + urlencode(params) - - def response_processor(response): - items = response.to_native() - if not items: - return [] - if type(items) is not list: - items = [items] - return items - - return HttpPaginatedResponseSync.paginated_query( - self.http, method, url, version=version, body=body, headers=headers, - response_processor=response_processor, - raise_on_error=False) - - def __exit__(self, *excinfo): - self.close() - - def close(self): - self.http.close() diff --git a/ably/sync/transport/__init__.py b/ably/sync/transport/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ably/sync/transport/defaults.py b/ably/sync/transport/defaults.py deleted file mode 100644 index 7a732d9a..00000000 --- a/ably/sync/transport/defaults.py +++ /dev/null @@ -1,63 +0,0 @@ -class Defaults: - protocol_version = "2" - fallback_hosts = [ - "a.ably-realtime.com", - "b.ably-realtime.com", - "c.ably-realtime.com", - "d.ably-realtime.com", - "e.ably-realtime.com", - ] - - rest_host = "rest.ably.io" - realtime_host = "realtime.ably.io" # RTN2 - connectivity_check_url = "https://internet-up.ably-realtime.com/is-the-internet-up.txt" - environment = 'production' - - port = 80 - tls_port = 443 - connect_timeout = 15000 - disconnect_timeout = 10000 - suspended_timeout = 60000 - comet_recv_timeout = 90000 - comet_send_timeout = 10000 - realtime_request_timeout = 10000 - channel_retry_timeout = 15000 - disconnected_retry_timeout = 15000 - connection_state_ttl = 120000 - suspended_retry_timeout = 30000 - - transports = [] # ["web_socket", "comet"] - - http_max_retry_count = 3 - - fallback_retry_timeout = 600000 # 10min - - @staticmethod - def get_port(options): - if options.tls: - if options.tls_port: - return options.tls_port - else: - return Defaults.tls_port - else: - if options.port: - return options.port - else: - return Defaults.port - - @staticmethod - def get_scheme(options): - if options.tls: - return "https" - else: - return "http" - - @staticmethod - def get_environment_fallback_hosts(environment): - return [ - environment + "-a-fallback.ably-realtime.com", - environment + "-b-fallback.ably-realtime.com", - environment + "-c-fallback.ably-realtime.com", - environment + "-d-fallback.ably-realtime.com", - environment + "-e-fallback.ably-realtime.com", - ] diff --git a/ably/sync/transport/websockettransport.py b/ably/sync/transport/websockettransport.py deleted file mode 100644 index 2de820d3..00000000 --- a/ably/sync/transport/websockettransport.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -import asyncio -from enum import IntEnum -import json -import logging -import socket -import urllib.parse -from ably.sync.http.httputils import HttpUtils -from ably.sync.types.connectiondetails import ConnectionDetails -from ably.sync.util.eventemitter import EventEmitter -from ably.sync.util.exceptions import AblyException -from ably.sync.util.helper import Timer, unix_time_ms -from websockets.client import WebSocketClientProtocol, connect as ws_connect -from websockets.exceptions import ConnectionClosedOK, WebSocketException - -if TYPE_CHECKING: - from ably.sync.realtime.connection import ConnectionManager - -log = logging.getLogger(__name__) - - -class ProtocolMessageAction(IntEnum): - HEARTBEAT = 0 - CONNECTED = 4 - DISCONNECTED = 6 - CLOSE = 7 - CLOSED = 8 - ERROR = 9 - ATTACH = 10 - ATTACHED = 11 - DETACH = 12 - DETACHED = 13 - MESSAGE = 15 - AUTH = 17 - - -class WebSocketTransport(EventEmitter): - def __init__(self, connection_manager: ConnectionManager, host: str, params: dict): - self.websocket: WebSocketClientProtocol | None = None - self.read_loop: asyncio.Task | None = None - self.connect_task: asyncio.Task | None = None - self.ws_connect_task: asyncio.Task | None = None - self.connection_manager = connection_manager - self.options = self.connection_manager.options - self.is_connected = False - self.idle_timer = None - self.last_activity = None - self.max_idle_interval = None - self.is_disposed = False - self.host = host - self.params = params - super().__init__() - - def connect(self): - headers = HttpUtils.default_headers() - query_params = urllib.parse.urlencode(self.params) - ws_url = (f'wss://{self.host}?{query_params}') - log.info(f'connect(): attempting to connect to {ws_url}') - self.ws_connect_task = asyncio.create_task(self.ws_connect(ws_url, headers)) - self.ws_connect_task.add_done_callback(self.on_ws_connect_done) - - def on_ws_connect_done(self, task: asyncio.Task): - try: - exception = task.exception() - except asyncio.CancelledError as e: - exception = e - if exception is None or isinstance(exception, ConnectionClosedOK): - return - log.info( - f'WebSocketTransport.on_ws_connect_done(): exception = {exception}' - ) - - def ws_connect(self, ws_url, headers): - try: - with ws_connect(ws_url, extra_headers=headers) as websocket: - log.info(f'ws_connect(): connection established to {ws_url}') - self._emit('connected') - self.websocket = websocket - self.read_loop = self.connection_manager.options.loop.create_task(self.ws_read_loop()) - self.read_loop.add_done_callback(self.on_read_loop_done) - try: - self.read_loop - except WebSocketException as err: - if not self.is_disposed: - self.dispose() - self.connection_manager.deactivate_transport(err) - except (WebSocketException, socket.gaierror) as e: - exception = AblyException(f'Error opening websocket connection: {e}', 400, 40000) - log.exception(f'WebSocketTransport.ws_connect(): Error opening websocket connection: {exception}') - self._emit('failed', exception) - raise exception - - def on_protocol_message(self, msg): - self.on_activity() - log.debug(f'WebSocketTransport.on_protocol_message(): received protocol message: {msg}') - action = msg.get('action') - if action == ProtocolMessageAction.CONNECTED: - connection_id = msg.get('connectionId') - connection_details = ConnectionDetails.from_dict(msg.get('connectionDetails')) - - error = msg.get('error') - exception = None - if error: - exception = AblyException.from_dict(error) - - max_idle_interval = connection_details.max_idle_interval - if max_idle_interval: - self.max_idle_interval = max_idle_interval + self.options.realtime_request_timeout - self.on_activity() - self.is_connected = True - if self.host != self.options.get_realtime_host(): # RTN17e - self.options.fallback_realtime_host = self.host - self.connection_manager.on_connected(connection_details, connection_id, reason=exception) - elif action == ProtocolMessageAction.DISCONNECTED: - error = msg.get('error') - exception = None - if error is not None: - exception = AblyException.from_dict(error) - self.connection_manager.on_disconnected(exception) - elif action == ProtocolMessageAction.AUTH: - try: - self.connection_manager.ably.auth.authorize() - except Exception as exc: - log.exception(f"WebSocketTransport.on_protocol_message(): An exception \ - occurred during reauth: {exc}") - elif action == ProtocolMessageAction.CLOSED: - if self.ws_connect_task: - self.ws_connect_task.cancel() - self.connection_manager.on_closed() - elif action == ProtocolMessageAction.ERROR: - error = msg.get('error') - exception = AblyException.from_dict(error) - self.connection_manager.on_error(msg, exception) - elif action == ProtocolMessageAction.HEARTBEAT: - id = msg.get('id') - self.connection_manager.on_heartbeat(id) - elif action in ( - ProtocolMessageAction.ATTACHED, - ProtocolMessageAction.DETACHED, - ProtocolMessageAction.MESSAGE - ): - self.connection_manager.on_channel_message(msg) - - def ws_read_loop(self): - if not self.websocket: - raise AblyException('ws_read_loop started with no websocket', 500, 50000) - try: - for raw in self.websocket: - msg = json.loads(raw) - task = asyncio.create_task(self.on_protocol_message(msg)) - task.add_done_callback(self.on_protcol_message_handled) - except ConnectionClosedOK: - return - - def on_protcol_message_handled(self, task): - try: - exception = task.exception() - except Exception as e: - exception = e - if exception is not None: - log.exception(f"WebSocketTransport.on_protocol_message_handled(): uncaught exception: {exception}") - - def on_read_loop_done(self, task: asyncio.Task): - try: - exception = task.exception() - except asyncio.CancelledError as e: - exception = e - if isinstance(exception, ConnectionClosedOK): - return - - def dispose(self): - self.is_disposed = True - if self.read_loop: - self.read_loop.cancel() - if self.ws_connect_task: - self.ws_connect_task.cancel() - if self.idle_timer: - self.idle_timer.cancel() - if self.websocket: - try: - self.websocket.close() - except asyncio.CancelledError: - return - - def close(self): - self.send({'action': ProtocolMessageAction.CLOSE}) - - def send(self, message: dict): - if self.websocket is None: - raise Exception() - raw_msg = json.dumps(message) - log.info(f'WebSocketTransport.send(): sending {raw_msg}') - self.websocket.send(raw_msg) - - def set_idle_timer(self, timeout: float): - if not self.idle_timer: - self.idle_timer = Timer(timeout, self.on_idle_timer_expire) - - def on_idle_timer_expire(self): - self.idle_timer = None - since_last = unix_time_ms() - self.last_activity - time_remaining = self.max_idle_interval - since_last - msg = f"No activity seen from realtime in {since_last} ms; assuming connection has dropped" - if time_remaining <= 0: - log.error(msg) - self.disconnect(AblyException(msg, 408, 80003)) - else: - self.set_idle_timer(time_remaining + 100) - - def on_activity(self): - if not self.max_idle_interval: - return - self.last_activity = unix_time_ms() - self.set_idle_timer(self.max_idle_interval + 100) - - def disconnect(self, reason=None): - self.dispose() - self.connection_manager.deactivate_transport(reason) diff --git a/ably/sync/types/__init__.py b/ably/sync/types/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ably/sync/types/authoptions.py b/ably/sync/types/authoptions.py deleted file mode 100644 index 77178f47..00000000 --- a/ably/sync/types/authoptions.py +++ /dev/null @@ -1,157 +0,0 @@ -from ably.sync.util.exceptions import AblyException - - -class AuthOptions: - def __init__(self, auth_callback=None, auth_url=None, auth_method='GET', - auth_token=None, auth_headers=None, auth_params=None, - key_name=None, key_secret=None, key=None, query_time=False, - token_details=None, use_token_auth=None, - default_token_params=None): - self.__auth_options = {} - self.auth_options['auth_callback'] = auth_callback - self.auth_options['auth_url'] = auth_url - self.auth_options['auth_method'] = auth_method - self.auth_options['auth_headers'] = auth_headers - self.auth_options['auth_params'] = auth_params - self.auth_options['query_time'] = query_time - self.auth_options['key_name'] = key_name - self.auth_options['key_secret'] = key_secret - self.set_key(key) - - self.__auth_token = auth_token - self.__token_details = token_details - self.__use_token_auth = use_token_auth - default_token_params = default_token_params or {} - default_token_params.pop('timestamp', None) - self.default_token_params = default_token_params - - def set_key(self, key): - if key is None: - return - - try: - key_name, key_secret = key.split(':') - self.auth_options['key_name'] = key_name - self.auth_options['key_secret'] = key_secret - except ValueError: - raise AblyException("key of not len 2 parameters: {0}" - .format(key.split(':')), - 401, 40101) - - def replace(self, auth_options): - if type(auth_options) is dict: - auth_options = dict(auth_options) - key = auth_options.pop('key', None) - self.auth_options = auth_options - self.set_key(key) - elif type(auth_options) is AuthOptions: - self.auth_options = dict(auth_options.auth_options) - else: - raise KeyError('Expected dict or AuthOptions') - - @property - def auth_options(self): - return self.__auth_options - - @auth_options.setter - def auth_options(self, value): - self.__auth_options = value - - @property - def auth_callback(self): - return self.auth_options['auth_callback'] - - @auth_callback.setter - def auth_callback(self, value): - self.auth_options['auth_callback'] = value - - @property - def auth_url(self): - return self.auth_options['auth_url'] - - @auth_url.setter - def auth_url(self, value): - self.auth_options['auth_url'] = value - - @property - def auth_method(self): - return self.auth_options['auth_method'] - - @auth_method.setter - def auth_method(self, value): - self.auth_options['auth_method'] = value.upper() - - @property - def key_name(self): - return self.auth_options['key_name'] - - @key_name.setter - def key_name(self, value): - self.auth_options['key_name'] = value - - @property - def key_secret(self): - return self.auth_options['key_secret'] - - @key_secret.setter - def key_secret(self, value): - self.auth_options['key_secret'] = value - - @property - def auth_token(self): - return self.__auth_token - - @auth_token.setter - def auth_token(self, value): - self.__auth_token = value - - @property - def auth_headers(self): - return self.auth_options['auth_headers'] - - @auth_headers.setter - def auth_headers(self, value): - self.auth_options['auth_headers'] = value - - @property - def auth_params(self): - return self.auth_options['auth_params'] - - @auth_params.setter - def auth_params(self, value): - self.auth_options['auth_params'] = value - - @property - def query_time(self): - return self.auth_options['query_time'] - - @query_time.setter - def query_time(self, value): - self.auth_options['query_time'] = value - - @property - def token_details(self): - return self.__token_details - - @token_details.setter - def token_details(self, value): - self.__token_details = value - - @property - def use_token_auth(self): - return self.__use_token_auth - - @use_token_auth.setter - def use_token_auth(self, value): - self.__use_token_auth = value - - @property - def default_token_params(self): - return self.__default_token_params - - @default_token_params.setter - def default_token_params(self, value): - self.__default_token_params = value - - def __str__(self): - return str(self.__dict__) diff --git a/ably/sync/types/capability.py b/ably/sync/types/capability.py deleted file mode 100644 index 5d209d7c..00000000 --- a/ably/sync/types/capability.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections.abc import MutableMapping -import json -import logging - - -log = logging.getLogger(__name__) - - -class Capability(MutableMapping): - def __init__(self, obj=None): - if obj is None: - obj = {} - self.__dict = dict(obj) - for k, v in obj.items(): - self[k] = v - - def __eq__(self, other): - if isinstance(other, Capability): - return Capability.c14n(self) == Capability.c14n(other) - return NotImplemented - - def __ne__(self, other): - if isinstance(other, Capability): - return Capability.c14n(self) != Capability.c14n(other) - return NotImplemented - - def __getitem__(self, key): - return self.__dict[key] - - def __iter__(self): - return iter(self.__dict) - - def __len__(self): - return len(self.__dict) - - def __contains__(self, key): - return key in self.__dict - - def __setitem__(self, key, value): - # validate that the value is a list of ops and that the key is a string - if not isinstance(key, str): - raise ValueError('Capability keys must be strings') - - if isinstance(value, str): - value = [value] - - operations = set() - for val in iter(value): - if not isinstance(val, str): - raise ValueError('Operations must be strings') - operations.add(val) - - self.__dict[key] = operations - - def __delitem__(self, key): - del self.__dict[key] - - def setdefault(self, key, default): - if key not in self: - self[key] = default - return self[key] - - def add_resource(self, resource, operations=None): - if operations is None: - operations = [] - if isinstance(operations, str): - operations = [operations] - self[resource] = list(operations) - - def add_operation_to_resource(self, operation, resource): - self.setdefault(resource, []).append(operation) - - def __str__(self): - return Capability.c14n(self) - - def to_dict(self): - return {k: sorted(v) for k, v in self.items()} - - @staticmethod - def c14n(capability): - sorted_ops = capability.to_dict() - return json.dumps(sorted_ops, sort_keys=True) diff --git a/ably/sync/types/channeldetails.py b/ably/sync/types/channeldetails.py deleted file mode 100644 index d959d487..00000000 --- a/ably/sync/types/channeldetails.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import annotations - - -class ChannelDetails: - - def __init__(self, channel_id, status): - self.__channel_id = channel_id - self.__status = status - - @property - def channel_id(self) -> str: - return self.__channel_id - - @property - def status(self) -> ChannelStatus: - return self.__status - - @staticmethod - def from_dict(obj): - kwargs = { - 'channel_id': obj.get("channelId"), - 'status': ChannelStatus.from_dict(obj.get("status")) - } - - return ChannelDetails(**kwargs) - - -class ChannelStatus: - - def __init__(self, is_active, occupancy): - self.__is_active = is_active - self.__occupancy = occupancy - - @property - def is_active(self) -> bool: - return self.__is_active - - @property - def occupancy(self) -> ChannelOccupancy: - return self.__occupancy - - @staticmethod - def from_dict(obj): - kwargs = { - 'is_active': obj.get("isActive"), - 'occupancy': ChannelOccupancy.from_dict(obj.get("occupancy")) - } - - return ChannelStatus(**kwargs) - - -class ChannelOccupancy: - - def __init__(self, metrics): - self.__metrics = metrics - - @property - def metrics(self) -> ChannelMetrics: - return self.__metrics - - @staticmethod - def from_dict(obj): - kwargs = { - 'metrics': ChannelMetrics.from_dict(obj.get("metrics")) - } - - return ChannelOccupancy(**kwargs) - - -class ChannelMetrics: - - def __init__(self, connections, presence_connections, presence_members, - presence_subscribers, publishers, subscribers): - self.__connections = connections - self.__presence_connections = presence_connections - self.__presence_members = presence_members - self.__presence_subscribers = presence_subscribers - self.__publishers = publishers - self.__subscribers = subscribers - - @property - def connections(self) -> int: - return self.__connections - - @property - def presence_connections(self) -> int: - return self.__presence_connections - - @property - def presence_members(self) -> int: - return self.__presence_members - - @property - def presence_subscribers(self) -> int: - return self.__presence_subscribers - - @property - def publishers(self) -> int: - return self.__publishers - - @property - def subscribers(self) -> int: - return self.__subscribers - - @staticmethod - def from_dict(obj): - kwargs = { - 'connections': obj.get("connections"), - 'presence_connections': obj.get("presenceConnections"), - 'presence_members': obj.get("presenceMembers"), - 'presence_subscribers': obj.get("presenceSubscribers"), - 'publishers': obj.get("publishers"), - 'subscribers': obj.get("subscribers") - } - - return ChannelMetrics(**kwargs) diff --git a/ably/sync/types/channelstate.py b/ably/sync/types/channelstate.py deleted file mode 100644 index 83352f7b..00000000 --- a/ably/sync/types/channelstate.py +++ /dev/null @@ -1,22 +0,0 @@ -from dataclasses import dataclass -from typing import Optional -from enum import Enum -from ably.sync.util.exceptions import AblyException - - -class ChannelState(str, Enum): - INITIALIZED = 'initialized' - ATTACHING = 'attaching' - ATTACHED = 'attached' - DETACHING = 'detaching' - DETACHED = 'detached' - SUSPENDED = 'suspended' - FAILED = 'failed' - - -@dataclass -class ChannelStateChange: - previous: ChannelState - current: ChannelState - resumed: bool - reason: Optional[AblyException] = None diff --git a/ably/sync/types/channelsubscription.py b/ably/sync/types/channelsubscription.py deleted file mode 100644 index fec042ad..00000000 --- a/ably/sync/types/channelsubscription.py +++ /dev/null @@ -1,70 +0,0 @@ -from ably.sync.util import case - - -class PushChannelSubscription: - - def __init__(self, channel, device_id=None, client_id=None, app_id=None): - if not device_id and not client_id: - raise ValueError('missing expected device or client id') - - if device_id and client_id: - raise ValueError('both device and client id given, only one expected') - - self.__channel = channel - self.__device_id = device_id - self.__client_id = client_id - self.__app_id = app_id - - @property - def channel(self): - return self.__channel - - @property - def device_id(self): - return self.__device_id - - @property - def client_id(self): - return self.__client_id - - @property - def app_id(self): - return self.__app_id - - def as_dict(self): - keys = ['channel', 'device_id', 'client_id', 'app_id'] - - obj = {} - for key in keys: - value = getattr(self, key) - if value is not None: - key = case.snake_to_camel(key) - obj[key] = value - - return obj - - @classmethod - def from_dict(cls, obj): - obj = {case.camel_to_snake(key): value for key, value in obj.items()} - return cls(**obj) - - @classmethod - def from_array(cls, array): - return [cls.from_dict(d) for d in array] - - @classmethod - def factory(cls, subscription): - if isinstance(subscription, cls): - return subscription - - return cls.from_dict(subscription) - - -def channel_subscriptions_response_processor(response): - native = response.to_native() - return PushChannelSubscription.from_array(native) - - -def channels_response_processor(response): - native = response.to_native() - return native diff --git a/ably/sync/types/connectiondetails.py b/ably/sync/types/connectiondetails.py deleted file mode 100644 index a281daed..00000000 --- a/ably/sync/types/connectiondetails.py +++ /dev/null @@ -1,20 +0,0 @@ -from dataclasses import dataclass - - -@dataclass() -class ConnectionDetails: - connection_state_ttl: int - max_idle_interval: int - connection_key: str - - def __init__(self, connection_state_ttl: int, max_idle_interval: int, - connection_key: str, client_id: str): - self.connection_state_ttl = connection_state_ttl - self.max_idle_interval = max_idle_interval - self.connection_key = connection_key - self.client_id = client_id - - @staticmethod - def from_dict(json_dict: dict): - return ConnectionDetails(json_dict.get('connectionStateTtl'), json_dict.get('maxIdleInterval'), - json_dict.get('connectionKey'), json_dict.get('clientId')) diff --git a/ably/sync/types/connectionerrors.py b/ably/sync/types/connectionerrors.py deleted file mode 100644 index e63ddea9..00000000 --- a/ably/sync/types/connectionerrors.py +++ /dev/null @@ -1,30 +0,0 @@ -from ably.sync.types.connectionstate import ConnectionState -from ably.sync.util.exceptions import AblyException - -ConnectionErrors = { - ConnectionState.DISCONNECTED: AblyException( - 'Connection to server temporarily unavailable', - 400, - 80003, - ), - ConnectionState.SUSPENDED: AblyException( - 'Connection to server unavailable', - 400, - 80002, - ), - ConnectionState.FAILED: AblyException( - 'Connection failed or disconnected by server', - 400, - 80000, - ), - ConnectionState.CLOSING: AblyException( - 'Connection closing', - 400, - 80017, - ), - ConnectionState.CLOSED: AblyException( - 'Connection closed', - 400, - 80017, - ), -} diff --git a/ably/sync/types/connectionstate.py b/ably/sync/types/connectionstate.py deleted file mode 100644 index 24747466..00000000 --- a/ably/sync/types/connectionstate.py +++ /dev/null @@ -1,36 +0,0 @@ -from enum import Enum -from dataclasses import dataclass -from typing import Optional - -from ably.sync.util.exceptions import AblyException - - -class ConnectionState(str, Enum): - INITIALIZED = 'initialized' - CONNECTING = 'connecting' - CONNECTED = 'connected' - DISCONNECTED = 'disconnected' - CLOSING = 'closing' - CLOSED = 'closed' - FAILED = 'failed' - SUSPENDED = 'suspended' - - -class ConnectionEvent(str, Enum): - INITIALIZED = 'initialized' - CONNECTING = 'connecting' - CONNECTED = 'connected' - DISCONNECTED = 'disconnected' - CLOSING = 'closing' - CLOSED = 'closed' - FAILED = 'failed' - SUSPENDED = 'suspended' - UPDATE = 'update' - - -@dataclass -class ConnectionStateChange: - previous: ConnectionState - current: ConnectionState - event: ConnectionEvent - reason: Optional[AblyException] = None # RTN4f diff --git a/ably/sync/types/device.py b/ably/sync/types/device.py deleted file mode 100644 index 5cfefa5c..00000000 --- a/ably/sync/types/device.py +++ /dev/null @@ -1,116 +0,0 @@ -from ably.sync.util import case - - -DevicePushTransportType = {'fcm', 'gcm', 'apns', 'web'} -DevicePlatform = {'android', 'ios', 'browser'} -DeviceFormFactor = {'phone', 'tablet', 'desktop', 'tv', 'watch', 'car', 'embedded', 'other'} - - -class DeviceDetails: - - def __init__(self, id, client_id=None, form_factor=None, metadata=None, - platform=None, push=None, update_token=None, app_id=None, - device_identity_token=None, modified=None, device_secret=None): - - if push: - recipient = push.get('recipient') - if recipient: - transport_type = recipient.get('transportType') - if transport_type is not None and transport_type not in DevicePushTransportType: - raise ValueError('unexpected transport type {}'.format(transport_type)) - - if platform is not None and platform not in DevicePlatform: - raise ValueError('unexpected platform {}'.format(platform)) - - if form_factor is not None and form_factor not in DeviceFormFactor: - raise ValueError('unexpected form factor {}'.format(form_factor)) - - self.__id = id - self.__client_id = client_id - self.__form_factor = form_factor - self.__metadata = metadata - self.__platform = platform - self.__push = push - self.__update_token = update_token - self.__app_id = app_id - self.__device_identity_token = device_identity_token - self.__modified = modified - self.__device_secret = device_secret - - @property - def id(self): - return self.__id - - @property - def client_id(self): - return self.__client_id - - @property - def form_factor(self): - return self.__form_factor - - @property - def metadata(self): - return self.__metadata - - @property - def platform(self): - return self.__platform - - @property - def push(self): - return self.__push - - @property - def update_token(self): - return self.__update_token - - @property - def app_id(self): - return self.__app_id - - @property - def device_identity_token(self): - return self.__device_identity_token - - @property - def modified(self): - return self.__modified - - @property - def device_secret(self): - return self.__device_secret - - def as_dict(self): - keys = ['id', 'client_id', 'form_factor', 'metadata', 'platform', - 'push', 'update_token', 'app_id', 'device_identity_token', 'modified', 'device_secret'] - - obj = {} - for key in keys: - value = getattr(self, key) - if value is not None: - key = case.snake_to_camel(key) - obj[key] = value - - return obj - - @classmethod - def from_dict(cls, obj): - obj = {case.camel_to_snake(key): value for key, value in obj.items()} - return cls(**obj) - - @classmethod - def from_array(cls, array): - return [cls.from_dict(d) for d in array] - - @classmethod - def factory(cls, device): - if isinstance(device, cls): - return device - - return cls.from_dict(device) - - -def device_details_response_processor(response): - native = response.to_native() - return DeviceDetails.from_array(native) diff --git a/ably/sync/types/flags.py b/ably/sync/types/flags.py deleted file mode 100644 index 1666434c..00000000 --- a/ably/sync/types/flags.py +++ /dev/null @@ -1,19 +0,0 @@ -from enum import Enum - - -class Flag(int, Enum): - # Channel attach state flags - HAS_PRESENCE = 1 << 0 - HAS_BACKLOG = 1 << 1 - RESUMED = 1 << 2 - TRANSIENT = 1 << 4 - ATTACH_RESUME = 1 << 5 - # Channel mode flags - PRESENCE = 1 << 16 - PUBLISH = 1 << 17 - SUBSCRIBE = 1 << 18 - PRESENCE_SUBSCRIBE = 1 << 19 - - -def has_flag(message_flags: int, flag: Flag): - return message_flags & flag > 0 diff --git a/ably/sync/types/message.py b/ably/sync/types/message.py deleted file mode 100644 index 43c0a03c..00000000 --- a/ably/sync/types/message.py +++ /dev/null @@ -1,233 +0,0 @@ -import base64 -import json -import logging - -from ably.sync.types.typedbuffer import TypedBuffer -from ably.sync.types.mixins import EncodeDataMixin -from ably.sync.util.crypto import CipherData -from ably.sync.util.exceptions import AblyException - -log = logging.getLogger(__name__) - - -def to_text(value): - if value is None: - return value - elif isinstance(value, str): - return value - elif isinstance(value, bytes): - return value.decode() - else: - raise TypeError("expected string or bytes, not %s" % type(value)) - - -class Message(EncodeDataMixin): - - def __init__(self, - name=None, # TM2g - data=None, # TM2d - client_id=None, # TM2b - id=None, # TM2a - connection_id=None, # TM2c - connection_key=None, # TM2h - encoding='', # TM2e - timestamp=None, # TM2f - extras=None, # TM2i - ): - - super().__init__(encoding) - - self.__name = to_text(name) - self.__data = data - self.__client_id = to_text(client_id) - self.__id = to_text(id) - self.__connection_id = connection_id - self.__connection_key = connection_key - self.__timestamp = timestamp - self.__extras = extras - - def __eq__(self, other): - if isinstance(other, Message): - return (self.name == other.name - and self.data == other.data - and self.client_id == other.client_id - and self.timestamp == other.timestamp) - return NotImplemented - - def __ne__(self, other): - if isinstance(other, Message): - result = self.__eq__(other) - if result != NotImplemented: - return not result - return NotImplemented - - @property - def name(self): - return self.__name - - @property - def data(self): - return self.__data - - @property - def client_id(self): - return self.__client_id - - @property - def id(self): - return self.__id - - @id.setter - def id(self, value): - self.__id = value - - @property - def connection_id(self): - return self.__connection_id - - @property - def connection_key(self): - return self.__connection_key - - @property - def timestamp(self): - return self.__timestamp - - @property - def extras(self): - return self.__extras - - def encrypt(self, channel_cipher): - if isinstance(self.data, CipherData): - return - - elif isinstance(self.data, str): - self._encoding_array.append('utf-8') - - if isinstance(self.data, dict) or isinstance(self.data, list): - self._encoding_array.append('json') - self._encoding_array.append('utf-8') - - typed_data = TypedBuffer.from_obj(self.data) - if typed_data.buffer is None: - return True - encrypted_data = channel_cipher.encrypt(typed_data.buffer) - self.__data = CipherData(encrypted_data, typed_data.type, - cipher_type=channel_cipher.cipher_type) - - @staticmethod - def decrypt_data(channel_cipher, data): - if not isinstance(data, CipherData): - return - decrypted_data = channel_cipher.decrypt(data.buffer) - decrypted_typed_buffer = TypedBuffer(decrypted_data, data.type) - - return decrypted_typed_buffer.decode() - - def decrypt(self, channel_cipher): - decrypted_data = self.decrypt_data(channel_cipher, self.__data) - if decrypted_data is not None: - self.__data = decrypted_data - - def as_dict(self, binary=False): - data = self.data - data_type = None - encoding = self._encoding_array[:] - - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) - - request_body = { - 'name': self.name, - 'data': data, - 'timestamp': self.timestamp or None, - 'type': data_type or None, - 'clientId': self.client_id or None, - 'id': self.id or None, - 'connectionId': self.connection_id or None, - 'connectionKey': self.connection_key or None, - 'extras': self.extras, - } - - if encoding: - request_body['encoding'] = '/'.join(encoding).strip('/') - - # None values aren't included - request_body = {k: v for k, v in request_body.items() if v is not None} - - return request_body - - @staticmethod - def from_encoded(obj, cipher=None): - id = obj.get('id') - name = obj.get('name') - data = obj.get('data') - client_id = obj.get('clientId') - connection_id = obj.get('connectionId') - timestamp = obj.get('timestamp') - encoding = obj.get('encoding', '') - extras = obj.get('extras', None) - - decoded_data = Message.decode(data, encoding, cipher) - - return Message( - id=id, - name=name, - connection_id=connection_id, - client_id=client_id, - timestamp=timestamp, - extras=extras, - **decoded_data - ) - - @staticmethod - def __update_empty_fields(proto_msg: dict, msg: dict, msg_index: int): - if msg.get("id") is None or msg.get("id") == '': - msg['id'] = f"{proto_msg.get('id')}:{msg_index}" - if msg.get("connectionId") is None or msg.get("connectionId") == '': - msg['connectionId'] = proto_msg.get('connectionId') - if msg.get("timestamp") is None or msg.get("timestamp") == 0: - msg['timestamp'] = proto_msg.get('timestamp') - - @staticmethod - def update_inner_message_fields(proto_msg: dict): - messages: list[dict] = proto_msg.get('messages') - presence_messages: list[dict] = proto_msg.get('presence') - if messages is not None: - msg_index = 0 - for msg in messages: - Message.__update_empty_fields(proto_msg, msg, msg_index) - msg_index = msg_index + 1 - - if presence_messages is not None: - msg_index = 0 - for presence_msg in presence_messages: - Message.__update_empty_fields(proto_msg, presence_msg, msg_index) - msg_index = msg_index + 1 - - -def make_message_response_handler(cipher): - def encrypted_message_response_handler(response): - messages = response.to_native() - return Message.from_encoded_array(messages, cipher=cipher) - return encrypted_message_response_handler diff --git a/ably/sync/types/mixins.py b/ably/sync/types/mixins.py deleted file mode 100644 index d228611b..00000000 --- a/ably/sync/types/mixins.py +++ /dev/null @@ -1,75 +0,0 @@ -import base64 -import json -import logging - -from ably.sync.util.crypto import CipherData - - -log = logging.getLogger(__name__) - - -class EncodeDataMixin: - - def __init__(self, encoding): - self.encoding = encoding - - @property - def encoding(self): - return '/'.join(self._encoding_array).strip('/') - - @encoding.setter - def encoding(self, encoding): - if not encoding: - self._encoding_array = [] - else: - self._encoding_array = encoding.strip('/').split('/') - - @staticmethod - def decode(data, encoding='', cipher=None): - encoding = encoding.strip('/') - encoding_list = encoding.split('/') - - while encoding_list: - encoding = encoding_list.pop() - if not encoding: - # With messagepack, binary data is sent as bytes, without need - # to specify the base64 encoding. Here we coerce to bytearray, - # since that's what is used with the Json transport; though it - # can be argued that it should be the other way, and use always - # bytes, never bytearray. - if type(data) is bytes: - data = bytearray(data) - continue - if encoding == 'json': - if isinstance(data, bytes): - data = data.decode() - if isinstance(data, list) or isinstance(data, dict): - continue - data = json.loads(data) - elif encoding == 'base64' and isinstance(data, bytes): - data = bytearray(base64.b64decode(data)) - elif encoding == 'base64': - data = bytearray(base64.b64decode(data.encode('utf-8'))) - elif encoding.startswith('%s+' % CipherData.ENCODING_ID): - if not cipher: - log.error('Message cannot be decrypted as the channel is ' - 'not set up for encryption & decryption') - encoding_list.append(encoding) - break - data = cipher.decrypt(data) - elif encoding == 'utf-8' and isinstance(data, (bytes, bytearray)): - data = data.decode('utf-8') - elif encoding == 'utf-8': - pass - else: - log.error('Message cannot be decoded. ' - "Unsupported encoding type: '%s'" % encoding) - encoding_list.append(encoding) - break - - encoding = '/'.join(encoding_list) - return {'encoding': encoding, 'data': data} - - @classmethod - def from_encoded_array(cls, objs, cipher=None): - return [cls.from_encoded(obj, cipher=cipher) for obj in objs] diff --git a/ably/sync/types/options.py b/ably/sync/types/options.py deleted file mode 100644 index fb2dae2a..00000000 --- a/ably/sync/types/options.py +++ /dev/null @@ -1,330 +0,0 @@ -import random -import logging - -from ably.sync.transport.defaults import Defaults -from ably.sync.types.authoptions import AuthOptions - -log = logging.getLogger(__name__) - - -class Options(AuthOptions): - def __init__(self, client_id=None, log_level=0, tls=True, rest_host=None, realtime_host=None, port=0, - tls_port=0, use_binary_protocol=True, queue_messages=False, recover=False, environment=None, - http_open_timeout=None, http_request_timeout=None, realtime_request_timeout=None, - http_max_retry_count=None, http_max_retry_duration=None, fallback_hosts=None, - fallback_retry_timeout=None, disconnected_retry_timeout=None, idempotent_rest_publishing=None, - loop=None, auto_connect=True, suspended_retry_timeout=None, connectivity_check_url=None, - channel_retry_timeout=Defaults.channel_retry_timeout, add_request_ids=False, **kwargs): - - super().__init__(**kwargs) - - # TODO check these defaults - if fallback_retry_timeout is None: - fallback_retry_timeout = Defaults.fallback_retry_timeout - - if realtime_request_timeout is None: - realtime_request_timeout = Defaults.realtime_request_timeout - - if disconnected_retry_timeout is None: - disconnected_retry_timeout = Defaults.disconnected_retry_timeout - - if connectivity_check_url is None: - connectivity_check_url = Defaults.connectivity_check_url - - connection_state_ttl = Defaults.connection_state_ttl - - if suspended_retry_timeout is None: - suspended_retry_timeout = Defaults.suspended_retry_timeout - - if environment is not None and rest_host is not None: - raise ValueError('specify rest_host or environment, not both') - - if environment is not None and realtime_host is not None: - raise ValueError('specify realtime_host or environment, not both') - - if idempotent_rest_publishing is None: - from ably.sync import api_version - idempotent_rest_publishing = api_version >= '1.2' - - if environment is None: - environment = Defaults.environment - - self.__client_id = client_id - self.__log_level = log_level - self.__tls = tls - self.__rest_host = rest_host - self.__realtime_host = realtime_host - self.__port = port - self.__tls_port = tls_port - self.__use_binary_protocol = use_binary_protocol - self.__queue_messages = queue_messages - self.__recover = recover - self.__environment = environment - self.__http_open_timeout = http_open_timeout - self.__http_request_timeout = http_request_timeout - self.__realtime_request_timeout = realtime_request_timeout - self.__http_max_retry_count = http_max_retry_count - self.__http_max_retry_duration = http_max_retry_duration - self.__fallback_hosts = fallback_hosts - self.__fallback_retry_timeout = fallback_retry_timeout - self.__disconnected_retry_timeout = disconnected_retry_timeout - self.__channel_retry_timeout = channel_retry_timeout - self.__idempotent_rest_publishing = idempotent_rest_publishing - self.__loop = loop - self.__auto_connect = auto_connect - self.__connection_state_ttl = connection_state_ttl - self.__suspended_retry_timeout = suspended_retry_timeout - self.__connectivity_check_url = connectivity_check_url - self.__fallback_realtime_host = None - self.__add_request_ids = add_request_ids - - self.__rest_hosts = self.__get_rest_hosts() - self.__realtime_hosts = self.__get_realtime_hosts() - - @property - def client_id(self): - return self.__client_id - - @client_id.setter - def client_id(self, value): - self.__client_id = value - - @property - def log_level(self): - return self.__log_level - - @log_level.setter - def log_level(self, value): - self.__log_level = value - - @property - def tls(self): - return self.__tls - - @tls.setter - def tls(self, value): - self.__tls = value - - @property - def rest_host(self): - return self.__rest_host - - @rest_host.setter - def rest_host(self, value): - self.__rest_host = value - - # RTC1d - @property - def realtime_host(self): - return self.__realtime_host - - @realtime_host.setter - def realtime_host(self, value): - self.__realtime_host = value - - @property - def port(self): - return self.__port - - @port.setter - def port(self, value): - self.__port = value - - @property - def tls_port(self): - return self.__tls_port - - @tls_port.setter - def tls_port(self, value): - self.__tls_port = value - - @property - def use_binary_protocol(self): - return self.__use_binary_protocol - - @use_binary_protocol.setter - def use_binary_protocol(self, value): - self.__use_binary_protocol = value - - @property - def queue_messages(self): - return self.__queue_messages - - @queue_messages.setter - def queue_messages(self, value): - self.__queue_messages = value - - @property - def recover(self): - return self.__recover - - @recover.setter - def recover(self, value): - self.__recover = value - - @property - def environment(self): - return self.__environment - - @property - def http_open_timeout(self): - return self.__http_open_timeout - - @http_open_timeout.setter - def http_open_timeout(self, value): - self.__http_open_timeout = value - - @property - def http_request_timeout(self): - return self.__http_request_timeout - - @property - def realtime_request_timeout(self): - return self.__realtime_request_timeout - - @http_request_timeout.setter - def http_request_timeout(self, value): - self.__http_request_timeout = value - - @property - def http_max_retry_count(self): - return self.__http_max_retry_count - - @http_max_retry_count.setter - def http_max_retry_count(self, value): - self.__http_max_retry_count = value - - @property - def http_max_retry_duration(self): - return self.__http_max_retry_duration - - @http_max_retry_duration.setter - def http_max_retry_duration(self, value): - self.__http_max_retry_duration = value - - @property - def fallback_hosts(self): - return self.__fallback_hosts - - @property - def fallback_retry_timeout(self): - return self.__fallback_retry_timeout - - @property - def disconnected_retry_timeout(self): - return self.__disconnected_retry_timeout - - @property - def channel_retry_timeout(self): - return self.__channel_retry_timeout - - @property - def idempotent_rest_publishing(self): - return self.__idempotent_rest_publishing - - @property - def loop(self): - return self.__loop - - # RTC1b - @property - def auto_connect(self): - return self.__auto_connect - - @property - def connection_state_ttl(self): - return self.__connection_state_ttl - - @connection_state_ttl.setter - def connection_state_ttl(self, value): - self.__connection_state_ttl = value - - @property - def suspended_retry_timeout(self): - return self.__suspended_retry_timeout - - @property - def connectivity_check_url(self): - return self.__connectivity_check_url - - @property - def fallback_realtime_host(self): - return self.__fallback_realtime_host - - @fallback_realtime_host.setter - def fallback_realtime_host(self, value): - self.__fallback_realtime_host = value - - @property - def add_request_ids(self): - return self.__add_request_ids - - def __get_rest_hosts(self): - """ - Return the list of hosts as they should be tried. First comes the main - host. Then the fallback hosts in random order. - The returned list will have a length of up to http_max_retry_count. - """ - # Defaults - host = self.rest_host - if host is None: - host = Defaults.rest_host - - environment = self.environment - - http_max_retry_count = self.http_max_retry_count - if http_max_retry_count is None: - http_max_retry_count = Defaults.http_max_retry_count - - # Prepend environment - if environment != 'production': - host = '%s-%s' % (environment, host) - - # Fallback hosts - fallback_hosts = self.fallback_hosts - if fallback_hosts is None: - if host == Defaults.rest_host: - fallback_hosts = Defaults.fallback_hosts - elif environment != 'production': - fallback_hosts = Defaults.get_environment_fallback_hosts(environment) - else: - fallback_hosts = [] - - # Shuffle - fallback_hosts = list(fallback_hosts) - random.shuffle(fallback_hosts) - self.__fallback_hosts = fallback_hosts - - # First main host - hosts = [host] + fallback_hosts - hosts = hosts[:http_max_retry_count] - return hosts - - def __get_realtime_hosts(self): - if self.realtime_host is not None: - host = self.realtime_host - return [host] - elif self.environment != "production": - host = f'{self.environment}-{Defaults.realtime_host}' - else: - host = Defaults.realtime_host - - return [host] + self.__fallback_hosts - - def get_rest_hosts(self): - return self.__rest_hosts - - def get_rest_host(self): - return self.__rest_hosts[0] - - def get_realtime_hosts(self): - return self.__realtime_hosts - - def get_realtime_host(self): - return self.__realtime_hosts[0] - - def get_fallback_rest_hosts(self): - return self.__rest_hosts[1:] - - def get_fallback_realtime_hosts(self): - return self.__realtime_hosts[1:] diff --git a/ably/sync/types/presence.py b/ably/sync/types/presence.py deleted file mode 100644 index 35a6b498..00000000 --- a/ably/sync/types/presence.py +++ /dev/null @@ -1,174 +0,0 @@ -from datetime import datetime, timedelta -from urllib import parse - -from ably.sync.http.paginatedresult import PaginatedResultSync -from ably.sync.types.mixins import EncodeDataMixin - - -def _ms_since_epoch(dt): - epoch = datetime.utcfromtimestamp(0) - delta = dt - epoch - return int(delta.total_seconds() * 1000) - - -def _dt_from_ms_epoch(ms): - epoch = datetime.utcfromtimestamp(0) - return epoch + timedelta(milliseconds=ms) - - -class PresenceAction: - ABSENT = 0 - PRESENT = 1 - ENTER = 2 - LEAVE = 3 - UPDATE = 4 - - -class PresenceMessage(EncodeDataMixin): - - def __init__(self, - id=None, # TP3a - action=None, # TP3b - client_id=None, # TP3c - connection_id=None, # TP3d - data=None, # TP3e - encoding=None, # TP3f - timestamp=None, # TP3g - member_key=None, # TP3h (for RT only) - extras=None, # TP3i (functionality not specified) - ): - - self.__id = id - self.__action = action - self.__client_id = client_id - self.__connection_id = connection_id - self.__data = data - self.__encoding = encoding - self.__timestamp = timestamp - self.__member_key = member_key - self.__extras = extras - - @property - def id(self): - return self.__id - - @property - def action(self): - return self.__action - - @property - def client_id(self): - return self.__client_id - - @property - def connection_id(self): - return self.__connection_id - - @property - def data(self): - return self.__data - - @property - def encoding(self): - return self.__encoding - - @property - def timestamp(self): - return self.__timestamp - - @property - def member_key(self): - if self.connection_id and self.client_id: - return "%s:%s" % (self.connection_id, self.client_id) - - @property - def extras(self): - return self.__extras - - @staticmethod - def from_encoded(obj, cipher=None): - id = obj.get('id') - action = obj.get('action', PresenceAction.ENTER) - client_id = obj.get('clientId') - connection_id = obj.get('connectionId') - data = obj.get('data') - encoding = obj.get('encoding', '') - timestamp = obj.get('timestamp') - # member_key = obj.get('memberKey', None) - extras = obj.get('extras', None) - - if timestamp is not None: - timestamp = _dt_from_ms_epoch(timestamp) - - decoded_data = PresenceMessage.decode(data, encoding, cipher) - - return PresenceMessage( - id=id, - action=action, - client_id=client_id, - connection_id=connection_id, - timestamp=timestamp, - extras=extras, - **decoded_data - ) - - -class Presence: - def __init__(self, channel): - self.__base_path = '/channels/%s/' % parse.quote_plus(channel.name) - self.__binary = channel.ably.options.use_binary_protocol - self.__http = channel.ably.http - self.__cipher = channel.cipher - - def _path_with_qs(self, rel_path, qs=None): - path = rel_path - if qs: - path += ('?' + parse.urlencode(qs)) - return path - - def get(self, limit=None): - qs = {} - if limit: - if limit > 1000: - raise ValueError("The maximum allowed limit is 1000") - qs['limit'] = limit - path = self._path_with_qs(self.__base_path + 'presence', qs) - - presence_handler = make_presence_response_handler(self.__cipher) - return PaginatedResultSync.paginated_query( - self.__http, url=path, response_processor=presence_handler) - - def history(self, limit=None, direction=None, start=None, end=None): - qs = {} - if limit: - if limit > 1000: - raise ValueError("The maximum allowed limit is 1000") - qs['limit'] = limit - if direction: - qs['direction'] = direction - if start: - if isinstance(start, int): - qs['start'] = start - else: - qs['start'] = _ms_since_epoch(start) - if end: - if isinstance(end, int): - qs['end'] = end - else: - qs['end'] = _ms_since_epoch(end) - - if 'start' in qs and 'end' in qs and qs['start'] > qs['end']: - raise ValueError("'end' parameter has to be greater than or equal to 'start'") - - path = self._path_with_qs(self.__base_path + 'presence/history', qs) - - presence_handler = make_presence_response_handler(self.__cipher) - return PaginatedResultSync.paginated_query( - self.__http, url=path, response_processor=presence_handler) - - -def make_presence_response_handler(cipher): - def encrypted_presence_response_handler(response): - messages = response.to_native() - return PresenceMessage.from_encoded_array(messages, cipher=cipher) - return encrypted_presence_response_handler diff --git a/ably/sync/types/stats.py b/ably/sync/types/stats.py deleted file mode 100644 index ead5e548..00000000 --- a/ably/sync/types/stats.py +++ /dev/null @@ -1,67 +0,0 @@ -import logging -from datetime import datetime - -log = logging.getLogger(__name__) - - -class Stats: - - def __init__(self, entries=None, unit=None, interval_id=None, in_progress=None, app_id=None, schema=None): - self.interval_id = interval_id or '' - self.entries = entries - self.unit = unit - self.interval_time = interval_from_interval_id(self.interval_id) - self.in_progress = in_progress - self.app_id = app_id - self.schema = schema - - @classmethod - def from_dict(cls, stats_dict): - stats_dict = stats_dict or {} - - kwargs = { - "entries": stats_dict.get("entries"), - "unit": stats_dict.get("unit"), - "interval_id": stats_dict.get("intervalId"), - "in_progress": stats_dict.get("inProgress"), - "app_id": stats_dict.get("appId"), - "schema": stats_dict.get("schema"), - } - - return cls(**kwargs) - - @classmethod - def from_array(cls, stats_array): - return [cls.from_dict(d) for d in stats_array] - - @staticmethod - def to_interval_id(date_time, granularity): - return date_time.strftime(INTERVALS_FMT[granularity]) - - -def stats_response_processor(response): - stats_array = response.to_native() - return Stats.from_array(stats_array) - - -INTERVALS_FMT = { - 'minute': '%Y-%m-%d:%H:%M', - 'hour': '%Y-%m-%d:%H', - 'day': '%Y-%m-%d', - 'month': '%Y-%m', -} - - -def granularity_from_interval_id(interval_id): - for key, value in INTERVALS_FMT.items(): - try: - datetime.strptime(interval_id, value) - return key - except ValueError: - pass - raise ValueError("Unsupported intervalId") - - -def interval_from_interval_id(interval_id): - granularity = granularity_from_interval_id(interval_id) - return datetime.strptime(interval_id, INTERVALS_FMT[granularity]) diff --git a/ably/sync/types/tokendetails.py b/ably/sync/types/tokendetails.py deleted file mode 100644 index 4a898a5b..00000000 --- a/ably/sync/types/tokendetails.py +++ /dev/null @@ -1,97 +0,0 @@ -import json -import time - -from ably.sync.types.capability import Capability - - -class TokenDetails: - - DEFAULTS = {'ttl': 60 * 60 * 1000} - # Buffer in milliseconds before a token is considered unusable - # For example, if buffer is 10000ms, the token can no longer be used for - # new requests 9000ms before it expires - TOKEN_EXPIRY_BUFFER = 15 * 1000 - - def __init__(self, token=None, expires=None, issued=0, - capability=None, client_id=None): - if expires is None: - self.__expires = time.time() * 1000 + TokenDetails.DEFAULTS['ttl'] - else: - self.__expires = expires - self.__token = token - self.__issued = issued - if capability and isinstance(capability, str): - try: - self.__capability = Capability(json.loads(capability)) - except json.JSONDecodeError: - self.__capability = Capability(json.loads(capability.replace("'", '"'))) - else: - self.__capability = Capability(capability or {}) - self.__client_id = client_id - - @property - def token(self): - return self.__token - - @property - def expires(self): - return self.__expires - - @property - def issued(self): - return self.__issued - - @property - def capability(self): - return self.__capability - - @property - def client_id(self): - return self.__client_id - - def to_dict(self): - return { - 'expires': self.expires, - 'token': self.token, - 'issued': self.issued, - 'capability': self.capability.to_dict(), - 'clientId': self.client_id, - } - - @staticmethod - def from_dict(obj): - kwargs = { - 'token': obj.get("token"), - 'capability': obj.get("capability"), - 'client_id': obj.get("clientId") - } - expires = obj.get("expires") - kwargs['expires'] = expires if expires is None else int(expires) - issued = obj.get("issued") - kwargs['issued'] = issued if issued is None else int(issued) - - return TokenDetails(**kwargs) - - @staticmethod - def from_json(data): - if isinstance(data, str): - data = json.loads(data) - - mapping = { - 'clientId': 'client_id', - } - for name in data: - py_name = mapping.get(name) - if py_name: - data[py_name] = data.pop(name) - - return TokenDetails(**data) - - def __eq__(self, other): - if isinstance(other, TokenDetails): - return (self.expires == other.expires - and self.token == other.token - and self.issued == other.issued - and self.capability == other.capability - and self.client_id == other.client_id) - return NotImplemented diff --git a/ably/sync/types/tokenrequest.py b/ably/sync/types/tokenrequest.py deleted file mode 100644 index d10a5eb3..00000000 --- a/ably/sync/types/tokenrequest.py +++ /dev/null @@ -1,107 +0,0 @@ -import base64 -import hashlib -import hmac -import json - - -class TokenRequest: - - def __init__(self, key_name=None, client_id=None, nonce=None, mac=None, - capability=None, ttl=None, timestamp=None): - self.__key_name = key_name - self.__client_id = client_id - self.__nonce = nonce - self.__mac = mac - self.__capability = capability - self.__ttl = ttl - self.__timestamp = timestamp - - def sign_request(self, key_secret): - sign_text = "\n".join([str(x) for x in [ - self.key_name or "", - self.ttl or "", - self.capability or "", - self.client_id or "", - "%d" % (self.timestamp or 0), - self.nonce or "", - "", # to get the trailing new line - ]]) - try: - key_secret = key_secret.encode('utf8') - except AttributeError: - pass - try: - sign_text = sign_text.encode('utf8') - except AttributeError: - pass - mac = hmac.new(key_secret, sign_text, hashlib.sha256).digest() - self.mac = base64.b64encode(mac).decode('utf8') - - def to_dict(self): - return { - 'keyName': self.key_name, - 'clientId': self.client_id, - 'ttl': self.ttl, - 'nonce': self.nonce, - 'capability': self.capability, - 'timestamp': self.timestamp, - 'mac': self.mac - } - - @staticmethod - def from_json(data): - if isinstance(data, str): - data = json.loads(data) - - mapping = { - 'keyName': 'key_name', - 'clientId': 'client_id', - } - for name, py_name in mapping.items(): - if name in data: - data[py_name] = data.pop(name) - - return TokenRequest(**data) - - def __eq__(self, other): - if isinstance(other, TokenRequest): - return (self.key_name == other.key_name - and self.client_id == other.client_id - and self.nonce == other.nonce - and self.mac == other.mac - and self.capability == other.capability - and self.ttl == other.ttl - and self.timestamp == other.timestamp) - return NotImplemented - - @property - def key_name(self): - return self.__key_name - - @property - def client_id(self): - return self.__client_id - - @property - def nonce(self): - return self.__nonce - - @property - def mac(self): - return self.__mac - - @mac.setter - def mac(self, mac): - self.__mac = mac - - @property - def capability(self): - return self.__capability - - @property - def ttl(self): - return self.__ttl - - @property - def timestamp(self): - return self.__timestamp diff --git a/ably/sync/types/typedbuffer.py b/ably/sync/types/typedbuffer.py deleted file mode 100644 index 56adcd88..00000000 --- a/ably/sync/types/typedbuffer.py +++ /dev/null @@ -1,104 +0,0 @@ -# This functionality is depreceated and will be removed -# Message Pack is the replacement for all binary data messages - -import json -import struct - - -class DataType: - NONE = 0 - TRUE = 1 - FALSE = 2 - INT32 = 3 - INT64 = 4 - DOUBLE = 5 - STRING = 6 - BUFFER = 7 - JSONARRAY = 8 - JSONOBJECT = 9 - - -class Limits: - INT32_MAX = 2 ** 31 - INT32_MIN = -(2 ** 31 + 1) - INT64_MAX = 2 ** 63 - INT64_MIN = - (2 ** 63 + 1) - - -_decoders = {DataType.TRUE: lambda b: True, - DataType.FALSE: lambda b: False, - DataType.INT32: lambda b: struct.unpack('>i', b)[0], - DataType.INT64: lambda b: struct.unpack('>q', b)[0], - DataType.DOUBLE: lambda b: struct.unpack('>d', b)[0], - DataType.STRING: lambda b: b.decode('utf-8'), - DataType.BUFFER: lambda b: b, - DataType.JSONARRAY: lambda b: json.loads(b.decode('utf-8')), - DataType.JSONOBJECT: lambda b: json.loads(b.decode('utf-8'))} - - -class TypedBuffer: - def __init__(self, buffer, type): - self.__buffer = buffer - self.__type = type - - def __eq__(self, other): - if isinstance(other, TypedBuffer): - return self.buffer == other.buffer and self.type == other.type - return NotImplemented - - def __ne__(self, other): - if isinstance(other, TypedBuffer): - result = self.__eq__(other) - if result != NotImplemented: - return not result - return NotImplemented - - @staticmethod - def from_obj(obj): - if isinstance(obj, TypedBuffer): - return obj - elif isinstance(obj, (bytes, bytearray)): - data_type = DataType.BUFFER - buffer = obj - elif isinstance(obj, str): - data_type = DataType.STRING - buffer = obj.encode('utf-8') - elif isinstance(obj, bool): - data_type = DataType.TRUE if obj else DataType.FALSE - buffer = None - elif isinstance(obj, int): - if Limits.INT32_MIN <= obj <= Limits.INT32_MAX: - data_type = DataType.INT32 - buffer = struct.pack('>i', obj) - elif Limits.INT64_MIN <= obj <= Limits.INT64_MAX: - data_type = DataType.INT64 - buffer = struct.pack('>q', obj) - else: - raise ValueError('Number too large %d' % obj) - elif isinstance(obj, float): - data_type = DataType.DOUBLE - buffer = struct.pack('>d', obj) - elif isinstance(obj, list): - data_type = DataType.JSONARRAY - buffer = json.dumps(obj, separators=(',', ':')).encode('utf-8') - elif isinstance(obj, dict): - data_type = DataType.JSONOBJECT - buffer = json.dumps(obj, separators=(',', ':')).encode('utf-8') - else: - raise TypeError('Unexpected object type %s' % type(obj)) - - return TypedBuffer(buffer, data_type) - - @property - def buffer(self): - return self.__buffer - - @property - def type(self): - return self.__type - - def decode(self): - decoder = _decoders.get(self.type) - if decoder is not None: - return decoder(self.buffer) - raise ValueError('Unsupported data type %s' % self.type) diff --git a/ably/sync/util/__init__.py b/ably/sync/util/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ably/sync/util/case.py b/ably/sync/util/case.py deleted file mode 100644 index 3b18c49e..00000000 --- a/ably/sync/util/case.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - - -first_cap_re = re.compile('(.)([A-Z][a-z]+)') -all_cap_re = re.compile('([a-z0-9])([A-Z])') - - -def camel_to_snake(name): - s1 = first_cap_re.sub(r'\1_\2', name) - return all_cap_re.sub(r'\1_\2', s1).lower() - - -def snake_to_camel(name): - name = name.split('_') - for i in range(1, len(name)): - name[i] = name[i].title() - - return ''.join(name) diff --git a/ably/sync/util/crypto.py b/ably/sync/util/crypto.py deleted file mode 100644 index bf1a9a35..00000000 --- a/ably/sync/util/crypto.py +++ /dev/null @@ -1,179 +0,0 @@ -import base64 -import logging - -try: - from Crypto.Cipher import AES - from Crypto import Random -except ImportError: - from .nocrypto import AES, Random - -from ably.sync.types.typedbuffer import TypedBuffer -from ably.sync.util.exceptions import AblyException - -log = logging.getLogger(__name__) - - -class CipherParams: - def __init__(self, algorithm='AES', mode='CBC', secret_key=None, iv=None): - self.__algorithm = algorithm.upper() - self.__secret_key = secret_key - self.__key_length = len(secret_key) * 8 if secret_key is not None else 128 - self.__mode = mode.upper() - self.__iv = iv - - @property - def algorithm(self): - return self.__algorithm - - @property - def secret_key(self): - return self.__secret_key - - @property - def iv(self): - return self.__iv - - @property - def key_length(self): - return self.__key_length - - @property - def mode(self): - return self.__mode - - -class CbcChannelCipher: - def __init__(self, cipher_params): - self.__secret_key = (cipher_params.secret_key or - self.__random(cipher_params.key_length / 8)) - if isinstance(self.__secret_key, str): - self.__secret_key = self.__secret_key.encode() - self.__iv = cipher_params.iv or self.__random(16) - self.__block_size = len(self.__iv) - if cipher_params.algorithm != 'AES': - raise NotImplementedError('Only AES algorithm is supported') - self.__algorithm = cipher_params.algorithm - if cipher_params.mode != 'CBC': - raise NotImplementedError('Only CBC mode is supported') - self.__mode = cipher_params.mode - self.__key_length = cipher_params.key_length - self.__encryptor = AES.new(self.__secret_key, AES.MODE_CBC, self.__iv) - - def __pad(self, data): - padding_size = self.__block_size - (len(data) % self.__block_size) - - padding_char = bytes((padding_size,)) - padded = data + padding_char * padding_size - - return padded - - def __unpad(self, data): - padding_size = data[-1] - - if padding_size > len(data): - # Too short - raise AblyException('invalid-padding', 0, 0) - - if padding_size == 0: - # Missing padding - raise AblyException('invalid-padding', 0, 0) - - for i in range(padding_size): - # Invalid padding bytes - if padding_size != data[-i - 1]: - raise AblyException('invalid-padding', 0, 0) - - return data[:-padding_size] - - def __random(self, length): - rndfile = Random.new() - return rndfile.read(length) - - def encrypt(self, plaintext): - if isinstance(plaintext, bytearray): - plaintext = bytes(plaintext) - padded_plaintext = self.__pad(plaintext) - encrypted = self.__iv + self.__encryptor.encrypt(padded_plaintext) - self.__iv = encrypted[-self.__block_size:] - return encrypted - - def decrypt(self, ciphertext): - if isinstance(ciphertext, bytearray): - ciphertext = bytes(ciphertext) - iv = ciphertext[:self.__block_size] - ciphertext = ciphertext[self.__block_size:] - decryptor = AES.new(self.__secret_key, AES.MODE_CBC, iv) - decrypted = decryptor.decrypt(ciphertext) - return bytearray(self.__unpad(decrypted)) - - @property - def secret_key(self): - return self.__secret_key - - @property - def iv(self): - return self.__iv - - @property - def cipher_type(self): - return ("%s-%s-%s" % (self.__algorithm, self.__key_length, - self.__mode)).lower() - - -class CipherData(TypedBuffer): - ENCODING_ID = 'cipher' - - def __init__(self, buffer, type, cipher_type=None, **kwargs): - self.__cipher_type = cipher_type - super().__init__(buffer, type, **kwargs) - - @property - def encoding_str(self): - return self.ENCODING_ID + '+' + self.__cipher_type - - -DEFAULT_KEYLENGTH = 256 -DEFAULT_BLOCKLENGTH = 16 - - -def generate_random_key(length=DEFAULT_KEYLENGTH): - rndfile = Random.new() - return rndfile.read(length // 8) - - -def get_default_params(params=None): - if type(params) in [str, bytes]: - raise ValueError("Calling get_default_params with a key directly is deprecated, it expects a params dict") - - key = params.get('key') - algorithm = params.get('algorithm') or 'AES' - iv = params.get('iv') or generate_random_key(DEFAULT_BLOCKLENGTH * 8) - mode = params.get('mode') or 'CBC' - - if not key: - raise ValueError("Crypto.get_default_params: a key is required") - - if type(key) == str: - key = base64.b64decode(key) - - cipher_params = CipherParams(algorithm=algorithm, secret_key=key, iv=iv, mode=mode) - validate_cipher_params(cipher_params) - return cipher_params - - -def get_cipher(params): - if isinstance(params, CipherParams): - cipher_params = params - else: - cipher_params = get_default_params(params) - return CbcChannelCipher(cipher_params) - - -def validate_cipher_params(cipher_params): - if cipher_params.algorithm == 'AES' and cipher_params.mode == 'CBC': - key_length = cipher_params.key_length - if key_length == 128 or key_length == 256: - return - raise ValueError( - 'Unsupported key length %s for aes-cbc encryption. Encryption key must be 128 or 256 bits' - ' (16 or 32 ASCII characters)' % key_length) diff --git a/ably/sync/util/eventemitter.py b/ably/sync/util/eventemitter.py deleted file mode 100644 index 47c139db..00000000 --- a/ably/sync/util/eventemitter.py +++ /dev/null @@ -1,185 +0,0 @@ -import asyncio -import logging -from pyee.asyncio import AsyncIOEventEmitter - -from ably.sync.util.helper import is_callable_or_coroutine - -# pyee's event emitter doesn't support attaching a listener to all events -# so to patch it, we create a wrapper which uses two event emitters, one -# is used to listen to all events and this arbitrary string is the event name -# used to emit all events on that listener -_all_event = 'all' - -log = logging.getLogger(__name__) - - -def _is_named_event_args(*args): - return len(args) == 2 and is_callable_or_coroutine(args[1]) - - -def _is_all_event_args(*args): - return len(args) == 1 and is_callable_or_coroutine(args[0]) - - -class EventEmitter: - """ - A generic interface for event registration and delivery used in a number of the types in the Realtime client - library. For example, the Connection object emits events for connection state using the EventEmitter pattern. - - Methods - ------- - on(*args) - Attach to channel - once(*args) - Detach from channel - off() - Subscribe to messages on a channel - """ - - def __init__(self): - self.__named_event_emitter = AsyncIOEventEmitter() - self.__all_event_emitter = AsyncIOEventEmitter() - self.__wrapped_listeners = {} - - def on(self, *args): - """ - Registers the provided listener for the specified event, if provided, and otherwise for all events. - If on() is called more than once with the same listener and event, the listener is added multiple times to - its listener registry. Therefore, as an example, assuming the same listener is registered twice using - on(), and an event is emitted once, the listener would be invoked twice. - - Parameters - ---------- - name : str - The named event to listen for. - listener : callable - The event listener. - """ - if _is_all_event_args(*args): - event = _all_event - listener = args[0] - emitter = self.__all_event_emitter - # self.__all_event_emitter.add_listener(_all_event, args[0]) - elif _is_named_event_args(*args): - event = args[0] - listener = args[1] - emitter = self.__named_event_emitter - # self.__named_event_emitter.add_listener(args[0], args[1]) - else: - raise ValueError("EventEmitter.on(): invalid args") - - if asyncio.iscoroutinefunction(listener): - def wrapped_listener(*args, **kwargs): - try: - listener(*args, **kwargs) - except Exception as err: - log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') - else: - def wrapped_listener(*args, **kwargs): - try: - listener(*args, **kwargs) - except Exception as err: - log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') - - self.__wrapped_listeners[listener] = wrapped_listener - - emitter.add_listener(event, wrapped_listener) - - def once(self, *args): - """ - Registers the provided listener for the first event that is emitted. If once() is called more than once - with the same listener, the listener is added multiple times to its listener registry. Therefore, as an - example, assuming the same listener is registered twice using once(), and an event is emitted once, the - listener would be invoked twice. However, all subsequent events emitted would not invoke the listener as - once() ensures that each registration is only invoked once. - - Parameters - ---------- - name : str - The named event to listen for. - listener : callable - The event listener. - """ - if _is_all_event_args(*args): - event = _all_event - listener = args[0] - emitter = self.__all_event_emitter - # self.__all_event_emitter.add_listener(_all_event, args[0]) - elif _is_named_event_args(*args): - event = args[0] - listener = args[1] - emitter = self.__named_event_emitter - # self.__named_event_emitter.add_listener(args[0], args[1]) - else: - raise ValueError("EventEmitter.on(): invalid args") - - if asyncio.iscoroutinefunction(listener): - def wrapped_listener(*args, **kwargs): - try: - listener(*args, **kwargs) - except Exception as err: - log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') - else: - def wrapped_listener(*args, **kwargs): - try: - listener(*args, **kwargs) - except Exception as err: - log.exception(f'EventEmitter.emit(): uncaught listener exception: {err}') - - self.__wrapped_listeners[listener] = wrapped_listener - - emitter.once(event, wrapped_listener) - - def off(self, *args): - """ - Removes all registrations that match both the specified listener and, if provided, the specified event. - If called with no arguments, deregisters all registrations, for all events and listeners. - - Parameters - ---------- - name : str - The named event to listen for. - listener : callable - The event listener. - """ - if len(args) == 0: - self.__all_event_emitter.remove_all_listeners() - self.__named_event_emitter.remove_all_listeners() - return - elif _is_all_event_args(*args): - event = _all_event - listener = args[0] - emitter = self.__all_event_emitter - elif _is_named_event_args(*args): - event = args[0] - listener = args[1] - emitter = self.__named_event_emitter - else: - raise ValueError("EventEmitter.once(): invalid args") - - wrapped_listener = self.__wrapped_listeners.get(listener) - - if wrapped_listener is None: - return - - emitter.remove_listener(event, wrapped_listener) - self.__wrapped_listeners[listener] = None - - def once_async(self, state=None): - future = asyncio.Future() - - def on_state_change(*args): - future.set_result(*args) - - if state is not None: - self.once(state, on_state_change) - else: - self.once(on_state_change) - - state_change = future - - return state_change - - def _emit(self, *args): - self.__named_event_emitter.emit(*args) - self.__all_event_emitter.emit(_all_event, *args[1:]) diff --git a/ably/sync/util/exceptions.py b/ably/sync/util/exceptions.py deleted file mode 100644 index 090cf3d8..00000000 --- a/ably/sync/util/exceptions.py +++ /dev/null @@ -1,92 +0,0 @@ -import functools -import logging - - -log = logging.getLogger(__name__) - - -class AblyException(Exception): - def __new__(cls, message, status_code, code, cause=None): - if cls == AblyException and status_code == 401: - return AblyAuthException(message, status_code, code, cause) - return super().__new__(cls, message, status_code, code, cause) - - def __init__(self, message, status_code, code, cause=None): - super().__init__() - self.message = message - self.code = code - self.status_code = status_code - self.cause = cause - - def __str__(self): - str = '%s %s %s' % (self.code, self.status_code, self.message) - if self.cause is not None: - str += ' (cause: %s)' % self.cause - return str - - @property - def is_server_error(self): - return 500 <= self.status_code <= 599 - - @staticmethod - def raise_for_response(response): - if 200 <= response.status_code < 300: - # Valid response - return - - try: - json_response = response.json() - except Exception: - log.debug("Response not json: %d %s", - response.status_code, - response.text) - raise AblyException(message=response.text, - status_code=response.status_code, - code=response.status_code * 100) - - if json_response and 'error' in json_response: - error = json_response['error'] - try: - raise AblyException( - message=error['message'], - status_code=error['statusCode'], - code=int(error['code']), - ) - except KeyError: - msg = "Unexpected exception decoding server response: %s" - msg = msg % response.text - raise AblyException(message=msg, status_code=500, code=50000) - - raise AblyException(message="", - status_code=response.status_code, - code=response.status_code * 100) - - @staticmethod - def from_exception(e): - if isinstance(e, AblyException): - return e - return AblyException("Unexpected exception: %s" % e, 500, 50000) - - @staticmethod - def from_dict(value: dict): - return AblyException(value.get('message'), value.get('statusCode'), value.get('code')) - - -def catch_all(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as e: - log.exception(e) - raise AblyException.from_exception(e) - - return wrapper - - -class AblyAuthException(AblyException): - pass - - -class IncompatibleClientIdException(AblyException): - pass diff --git a/ably/sync/util/helper.py b/ably/sync/util/helper.py deleted file mode 100644 index a844204e..00000000 --- a/ably/sync/util/helper.py +++ /dev/null @@ -1,42 +0,0 @@ -import inspect -import random -import string -import asyncio -import time -from typing import Callable - - -def get_random_id(): - # get random string of letters and digits - source = string.ascii_letters + string.digits - random_id = ''.join((random.choice(source) for i in range(8))) - return random_id - - -def is_callable_or_coroutine(value): - return asyncio.iscoroutinefunction(value) or inspect.isfunction(value) or inspect.ismethod(value) - - -def unix_time_ms(): - return round(time.time_ns() / 1_000_000) - - -def is_token_error(exception): - return 40140 <= exception.code < 40150 - - -class Timer: - def __init__(self, timeout: float, callback: Callable): - self._timeout = timeout - self._callback = callback - self._task = asyncio.create_task(self._job()) - - def _job(self): - asyncio.sleep(self._timeout / 1000) - if asyncio.iscoroutinefunction(self._callback): - self._callback() - else: - self._callback() - - def cancel(self): - self._task.cancel() diff --git a/ably/sync/util/nocrypto.py b/ably/sync/util/nocrypto.py deleted file mode 100644 index a66669b3..00000000 --- a/ably/sync/util/nocrypto.py +++ /dev/null @@ -1,9 +0,0 @@ - -class InstallPycrypto: - def __getattr__(self, name): - raise ImportError( - "This requires to install ably with crypto support: pip install 'ably[crypto]'" - ) - - -AES = Random = InstallPycrypto() diff --git a/test/ably/sync/rest/sync_encoders_test.py b/test/ably/sync/rest/sync_encoders_test.py deleted file mode 100644 index d70b22d3..00000000 --- a/test/ably/sync/rest/sync_encoders_test.py +++ /dev/null @@ -1,456 +0,0 @@ -import base64 -import json -import logging -import sys - -import mock -import msgpack - -from ably.sync import CipherParams -from ably.sync.util.crypto import get_cipher -from ably.sync.types.message import Message - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import BaseAsyncTestCase - -if sys.version_info >= (3, 8): - from unittest.mock import Mock -else: - from mock import Mock - -log = logging.getLogger(__name__) - - -class TestTextEncodersNoEncryption(BaseAsyncTestCase): - def setUp(self): - self.ably = TestApp.get_ably_rest(use_binary_protocol=False) - - def tearDown(self): - self.ably.close() - - def test_text_utf8(self): - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', 'foó') - _, kwargs = post_mock.call_args - assert json.loads(kwargs['body'])['data'] == 'foó' - assert not json.loads(kwargs['body']).get('encoding', '') - - def test_str(self): - # This test only makes sense for py2 - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', 'foo') - _, kwargs = post_mock.call_args - assert json.loads(kwargs['body'])['data'] == 'foo' - assert not json.loads(kwargs['body']).get('encoding', '') - - def test_with_binary_type(self): - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', bytearray(b'foo')) - _, kwargs = post_mock.call_args - raw_data = json.loads(kwargs['body'])['data'] - assert base64.b64decode(raw_data.encode('ascii')) == bytearray(b'foo') - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'base64' - - def test_with_bytes_type(self): - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', b'foo') - _, kwargs = post_mock.call_args - raw_data = json.loads(kwargs['body'])['data'] - assert base64.b64decode(raw_data.encode('ascii')) == bytearray(b'foo') - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'base64' - - def test_with_json_dict_data(self): - channel = self.ably.channels["persisted:publish"] - data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - raw_data = json.loads(json.loads(kwargs['body'])['data']) - assert raw_data == data - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json' - - def test_with_json_list_data(self): - channel = self.ably.channels["persisted:publish"] - data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - raw_data = json.loads(json.loads(kwargs['body'])['data']) - assert raw_data == data - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json' - - def test_text_utf8_decode(self): - channel = self.ably.channels["persisted:stringdecode"] - - channel.publish('event', 'fóo') - history = channel.history() - message = history.items[0] - assert message.data == 'fóo' - assert isinstance(message.data, str) - assert not message.encoding - - def test_text_str_decode(self): - channel = self.ably.channels["persisted:stringnonutf8decode"] - - channel.publish('event', 'foo') - history = channel.history() - message = history.items[0] - assert message.data == 'foo' - assert isinstance(message.data, str) - assert not message.encoding - - def test_with_binary_type_decode(self): - channel = self.ably.channels["persisted:binarydecode"] - - channel.publish('event', bytearray(b'foob')) - history = channel.history() - message = history.items[0] - assert message.data == bytearray(b'foob') - assert isinstance(message.data, bytearray) - assert not message.encoding - - def test_with_json_dict_data_decode(self): - channel = self.ably.channels["persisted:jsondict"] - data = {'foó': 'bár'} - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - def test_with_json_list_data_decode(self): - channel = self.ably.channels["persisted:jsonarray"] - data = ['foó', 'bár'] - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - def test_decode_with_invalid_encoding(self): - data = 'foó' - encoded = base64.b64encode(data.encode('utf-8')) - decoded_data = Message.decode(encoded, 'foo/bar/utf-8/base64') - assert decoded_data['data'] == data - assert decoded_data['encoding'] == 'foo/bar' - - -class TestTextEncodersEncryption(BaseAsyncTestCase): - def setUp(self): - self.ably = TestApp.get_ably_rest(use_binary_protocol=False) - self.cipher_params = CipherParams(secret_key='keyfordecrypt_16', - algorithm='aes') - - def tearDown(self): - self.ably.close() - - def decrypt(self, payload, options=None): - if options is None: - options = {} - ciphertext = base64.b64decode(payload.encode('ascii')) - cipher = get_cipher({'key': b'keyfordecrypt_16'}) - return cipher.decrypt(ciphertext) - - def test_text_utf8(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', 'fóo') - _, kwargs = post_mock.call_args - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'utf-8/cipher+aes-128-cbc/base64' - data = self.decrypt(json.loads(kwargs['body'])['data']).decode('utf-8') - assert data == 'fóo' - - def test_str(self): - # This test only makes sense for py2 - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', 'foo') - _, kwargs = post_mock.call_args - assert json.loads(kwargs['body'])['data'] == 'foo' - assert not json.loads(kwargs['body']).get('encoding', '') - - def test_with_binary_type(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', bytearray(b'foo')) - _, kwargs = post_mock.call_args - - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'cipher+aes-128-cbc/base64' - data = self.decrypt(json.loads(kwargs['body'])['data']) - assert data == bytearray(b'foo') - assert isinstance(data, bytearray) - - def test_with_json_dict_data(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' - raw_data = self.decrypt(json.loads(kwargs['body'])['data']).decode('ascii') - assert json.loads(raw_data) == data - - def test_with_json_list_data(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.HttpSync.post', new_callable=Mock) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - assert json.loads(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc/base64' - raw_data = self.decrypt(json.loads(kwargs['body'])['data']).decode('ascii') - assert json.loads(raw_data) == data - - def test_text_utf8_decode(self): - channel = self.ably.channels.get("persisted:enc_stringdecode", - cipher=self.cipher_params) - channel.publish('event', 'foó') - history = channel.history() - message = history.items[0] - assert message.data == 'foó' - assert isinstance(message.data, str) - assert not message.encoding - - def test_with_binary_type_decode(self): - channel = self.ably.channels.get("persisted:enc_binarydecode", - cipher=self.cipher_params) - - channel.publish('event', bytearray(b'foob')) - history = channel.history() - message = history.items[0] - assert message.data == bytearray(b'foob') - assert isinstance(message.data, bytearray) - assert not message.encoding - - def test_with_json_dict_data_decode(self): - channel = self.ably.channels.get("persisted:enc_jsondict", - cipher=self.cipher_params) - data = {'foó': 'bár'} - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - def test_with_json_list_data_decode(self): - channel = self.ably.channels.get("persisted:enc_list", - cipher=self.cipher_params) - data = ['foó', 'bár'] - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - -class TestBinaryEncodersNoEncryption(BaseAsyncTestCase): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - - def tearDown(self): - self.ably.close() - - def decode(self, data): - return msgpack.unpackb(data) - - def test_text_utf8(self): - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', 'foó') - _, kwargs = post_mock.call_args - assert self.decode(kwargs['body'])['data'] == 'foó' - assert self.decode(kwargs['body']).get('encoding', '').strip('/') == '' - - def test_with_binary_type(self): - channel = self.ably.channels["persisted:publish"] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', bytearray(b'foo')) - _, kwargs = post_mock.call_args - assert self.decode(kwargs['body'])['data'] == bytearray(b'foo') - assert self.decode(kwargs['body']).get('encoding', '').strip('/') == '' - - def test_with_json_dict_data(self): - channel = self.ably.channels["persisted:publish"] - data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - raw_data = json.loads(self.decode(kwargs['body'])['data']) - assert raw_data == data - assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json' - - def test_with_json_list_data(self): - channel = self.ably.channels["persisted:publish"] - data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - raw_data = json.loads(self.decode(kwargs['body'])['data']) - assert raw_data == data - assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json' - - def test_text_utf8_decode(self): - channel = self.ably.channels["persisted:stringdecode-bin"] - - channel.publish('event', 'fóo') - history = channel.history() - message = history.items[0] - assert message.data == 'fóo' - assert isinstance(message.data, str) - assert not message.encoding - - def test_with_binary_type_decode(self): - channel = self.ably.channels["persisted:binarydecode-bin"] - - channel.publish('event', bytearray(b'foob')) - history = channel.history() - message = history.items[0] - assert message.data == bytearray(b'foob') - assert not message.encoding - - def test_with_json_dict_data_decode(self): - channel = self.ably.channels["persisted:jsondict-bin"] - data = {'foó': 'bár'} - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - def test_with_json_list_data_decode(self): - channel = self.ably.channels["persisted:jsonarray-bin"] - data = ['foó', 'bár'] - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - -class TestBinaryEncodersEncryption(BaseAsyncTestCase): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - self.cipher_params = CipherParams(secret_key='keyfordecrypt_16', algorithm='aes') - - def tearDown(self): - self.ably.close() - - def decrypt(self, payload, options=None): - if options is None: - options = {} - cipher = get_cipher({'key': b'keyfordecrypt_16'}) - return cipher.decrypt(payload) - - def decode(self, data): - return msgpack.unpackb(data) - - def test_text_utf8(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', 'fóo') - _, kwargs = post_mock.call_args - assert self.decode(kwargs['body'])['encoding'].strip('/') == 'utf-8/cipher+aes-128-cbc' - data = self.decrypt(self.decode(kwargs['body'])['data']).decode('utf-8') - assert data == 'fóo' - - def test_with_binary_type(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', bytearray(b'foo')) - _, kwargs = post_mock.call_args - - assert self.decode(kwargs['body'])['encoding'].strip('/') == 'cipher+aes-128-cbc' - data = self.decrypt(self.decode(kwargs['body'])['data']) - assert data == bytearray(b'foo') - assert isinstance(data, bytearray) - - def test_with_json_dict_data(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - data = {'foó': 'bár'} - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc' - raw_data = self.decrypt(self.decode(kwargs['body'])['data']).decode('ascii') - assert json.loads(raw_data) == data - - def test_with_json_list_data(self): - channel = self.ably.channels.get("persisted:publish_enc", - cipher=self.cipher_params) - data = ['foó', 'bár'] - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish('event', data) - _, kwargs = post_mock.call_args - assert self.decode(kwargs['body'])['encoding'].strip('/') == 'json/utf-8/cipher+aes-128-cbc' - raw_data = self.decrypt(self.decode(kwargs['body'])['data']).decode('ascii') - assert json.loads(raw_data) == data - - def test_text_utf8_decode(self): - channel = self.ably.channels.get("persisted:enc_stringdecode-bin", - cipher=self.cipher_params) - channel.publish('event', 'foó') - history = channel.history() - message = history.items[0] - assert message.data == 'foó' - assert isinstance(message.data, str) - assert not message.encoding - - def test_with_binary_type_decode(self): - channel = self.ably.channels.get("persisted:enc_binarydecode-bin", - cipher=self.cipher_params) - - channel.publish('event', bytearray(b'foob')) - history = channel.history() - message = history.items[0] - assert message.data == bytearray(b'foob') - assert isinstance(message.data, bytearray) - assert not message.encoding - - def test_with_json_dict_data_decode(self): - channel = self.ably.channels.get("persisted:enc_jsondict-bin", - cipher=self.cipher_params) - data = {'foó': 'bár'} - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding - - def test_with_json_list_data_decode(self): - channel = self.ably.channels.get("persisted:enc_list-bin", - cipher=self.cipher_params) - data = ['foó', 'bár'] - channel.publish('event', data) - history = channel.history() - message = history.items[0] - assert message.data == data - assert not message.encoding diff --git a/test/ably/sync/rest/sync_restauth_test.py b/test/ably/sync/rest/sync_restauth_test.py deleted file mode 100644 index e4f3560b..00000000 --- a/test/ably/sync/rest/sync_restauth_test.py +++ /dev/null @@ -1,652 +0,0 @@ -import logging -import sys -import time -import uuid -import base64 - -from urllib.parse import parse_qs -import mock -import pytest -import respx -from httpx import Response, Client - -import ably -from ably.sync import AblyRestSync -from ably.sync import AuthSync -from ably.sync import AblyAuthException -from ably.sync.types.tokendetails import TokenDetails - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - -if sys.version_info >= (3, 8): - from unittest.mock import Mock -else: - from mock import Mock - -log = logging.getLogger(__name__) - - -# does not make any request, no need to vary by protocol -class TestAuth(BaseAsyncTestCase): - def setUp(self): - self.test_vars = TestApp.get_test_vars() - - def test_auth_init_key_only(self): - ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"]) - assert AuthSync.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" - assert ably.auth.auth_options.key_name == self.test_vars["keys"][0]['key_name'] - assert ably.auth.auth_options.key_secret == self.test_vars["keys"][0]['key_secret'] - - def test_auth_init_token_only(self): - ably = AblyRestSync(token="this_is_not_really_a_token") - - assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" - - def test_auth_token_details(self): - td = TokenDetails() - ably = AblyRestSync(token_details=td) - - assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism - assert ably.auth.token_details is td - - def test_auth_init_with_token_callback(self): - callback_called = [] - - def token_callback(token_params): - callback_called.append(True) - return "this_is_not_really_a_token_request" - - ably = TestApp.get_ably_rest( - key=None, - key_name=self.test_vars["keys"][0]["key_name"], - auth_callback=token_callback) - - try: - ably.stats(None) - except Exception: - pass - - assert callback_called, "Token callback not called" - assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" - - def test_auth_init_with_key_and_client_id(self): - ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], client_id='testClientId') - - assert AuthSync.Method.BASIC == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" - assert ably.auth.client_id == 'testClientId' - - def test_auth_init_with_token(self): - ably = TestApp.get_ably_rest(key=None, token="this_is_not_really_a_token") - assert AuthSync.Method.TOKEN == ably.auth.auth_mechanism, "Unexpected Auth method mismatch" - - # RSA11 - def test_request_basic_auth_header(self): - ably = AblyRestSync(key_secret='foo', key_name='bar') - - with mock.patch.object(Client, 'send') as get_mock: - try: - ably.http.get('/time', skip_auth=False) - except Exception: - pass - request = get_mock.call_args_list[0][0][0] - authorization = request.headers['Authorization'] - assert authorization == 'Basic %s' % base64.b64encode('bar:foo'.encode('ascii')).decode('utf-8') - - # RSA7e2 - def test_request_basic_auth_header_with_client_id(self): - ably = AblyRestSync(key_secret='foo', key_name='bar', client_id='client_id') - - with mock.patch.object(Client, 'send') as get_mock: - try: - ably.http.get('/time', skip_auth=False) - except Exception: - pass - request = get_mock.call_args_list[0][0][0] - client_id = request.headers['x-ably-clientid'] - assert client_id == base64.b64encode('client_id'.encode('ascii')).decode('utf-8') - - def test_request_token_auth_header(self): - ably = AblyRestSync(token='not_a_real_token') - - with mock.patch.object(Client, 'send') as get_mock: - try: - ably.http.get('/time', skip_auth=False) - except Exception: - pass - request = get_mock.call_args_list[0][0][0] - authorization = request.headers['Authorization'] - assert authorization == 'Bearer %s' % base64.b64encode('not_a_real_token'.encode('ascii')).decode('utf-8') - - def test_if_cant_authenticate_via_token(self): - with pytest.raises(ValueError): - AblyRestSync(use_token_auth=True) - - def test_use_auth_token(self): - ably = AblyRestSync(use_token_auth=True, key=self.test_vars["keys"][0]["key_str"]) - assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN - - def test_with_client_id(self): - ably = AblyRestSync(use_token_auth=True, client_id='client_id', key=self.test_vars["keys"][0]["key_str"]) - assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN - - def test_with_auth_url(self): - ably = AblyRestSync(auth_url='auth_url') - assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN - - def test_with_auth_callback(self): - ably = AblyRestSync(auth_callback=lambda x: x) - assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN - - def test_with_token(self): - ably = AblyRestSync(token='a token') - assert ably.auth.auth_mechanism == AuthSync.Method.TOKEN - - def test_default_ttl_is_1hour(self): - one_hour_in_ms = 60 * 60 * 1000 - assert TokenDetails.DEFAULTS['ttl'] == one_hour_in_ms - - def test_with_auth_method(self): - ably = AblyRestSync(token='a token', auth_method='POST') - assert ably.auth.auth_options.auth_method == 'POST' - - def test_with_auth_headers(self): - ably = AblyRestSync(token='a token', auth_headers={'h1': 'v1'}) - assert ably.auth.auth_options.auth_headers == {'h1': 'v1'} - - def test_with_auth_params(self): - ably = AblyRestSync(token='a token', auth_params={'p': 'v'}) - assert ably.auth.auth_options.auth_params == {'p': 'v'} - - def test_with_default_token_params(self): - ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], - default_token_params={'ttl': 12345}) - assert ably.auth.auth_options.default_token_params == {'ttl': 12345} - - -class TestAuthAuthorize(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - self.test_vars = TestApp.get_test_vars() - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - def test_if_authorize_changes_auth_mechanism_to_token(self): - assert AuthSync.Method.BASIC == self.ably.auth.auth_mechanism, "Unexpected Auth method mismatch" - - self.ably.auth.authorize() - - assert AuthSync.Method.TOKEN == self.ably.auth.auth_mechanism, "Authorize should change the Auth method" - - # RSA10a - @dont_vary_protocol - def test_authorize_always_creates_new_token(self): - self.ably.auth.authorize({'capability': {'test': ['publish']}}) - self.ably.channels.test.publish('event', 'data') - - self.ably.auth.authorize({'capability': {'test': ['subscribe']}}) - with pytest.raises(AblyAuthException): - self.ably.channels.test.publish('event', 'data') - - def test_authorize_create_new_token_if_expired(self): - token = self.ably.auth.authorize() - with mock.patch('ably.rest.auth.Auth.token_details_has_expired', - return_value=True): - new_token = self.ably.auth.authorize() - - assert token is not new_token - - def test_authorize_returns_a_token_details(self): - token = self.ably.auth.authorize() - assert isinstance(token, TokenDetails) - - @dont_vary_protocol - def test_authorize_adheres_to_request_token(self): - token_params = {'ttl': 10, 'client_id': 'client_id'} - auth_params = {'auth_url': 'somewhere.com', 'query_time': True} - with mock.patch('ably.sync.rest.auth.AuthSync.request_token', new_callable=Mock) as request_mock: - self.ably.auth.authorize(token_params, auth_params) - - token_called, auth_called = request_mock.call_args - assert token_called[0] == token_params - - # Authorize may call request_token with some default auth_options. - for arg, value in auth_params.items(): - assert auth_called[arg] == value, "%s called with wrong value: %s" % (arg, value) - - def test_with_token_str_https(self): - token = self.ably.auth.authorize() - token = token.token - ably = TestApp.get_ably_rest(key=None, token=token, tls=True, - use_binary_protocol=self.use_binary_protocol) - ably.channels.test_auth_with_token_str.publish('event', 'foo_bar') - ably.close() - - def test_with_token_str_http(self): - token = self.ably.auth.authorize() - token = token.token - ably = TestApp.get_ably_rest(key=None, token=token, tls=False, - use_binary_protocol=self.use_binary_protocol) - ably.channels.test_auth_with_token_str.publish('event', 'foo_bar') - ably.close() - - def test_if_default_client_id_is_used(self): - ably = TestApp.get_ably_rest(client_id='my_client_id', - use_binary_protocol=self.use_binary_protocol) - token = ably.auth.authorize() - assert token.client_id == 'my_client_id' - ably.close() - - # RSA10j - def test_if_parameters_are_stored_and_used_as_defaults(self): - # Define some parameters - auth_options = dict(self.ably.auth.auth_options.auth_options) - auth_options['auth_headers'] = {'a_headers': 'a_value'} - self.ably.auth.authorize({'ttl': 555}, auth_options) - with mock.patch('ably.sync.rest.auth.AuthSync.request_token', - wraps=self.ably.auth.request_token) as request_mock: - self.ably.auth.authorize() - - token_called, auth_called = request_mock.call_args - assert token_called[0] == {'ttl': 555} - assert auth_called['auth_headers'] == {'a_headers': 'a_value'} - - # Different parameters, should completely replace the first ones, not merge - auth_options = dict(self.ably.auth.auth_options.auth_options) - auth_options['auth_headers'] = None - self.ably.auth.authorize({}, auth_options) - with mock.patch('ably.sync.rest.auth.AuthSync.request_token', - wraps=self.ably.auth.request_token) as request_mock: - self.ably.auth.authorize() - - token_called, auth_called = request_mock.call_args - assert token_called[0] == {} - assert auth_called['auth_headers'] is None - - # RSA10g - def test_timestamp_is_not_stored(self): - # authorize once with arbitrary defaults - auth_options = dict(self.ably.auth.auth_options.auth_options) - auth_options['auth_headers'] = {'a_headers': 'a_value'} - token_1 = self.ably.auth.authorize( - {'ttl': 60 * 1000, 'client_id': 'new_id'}, - auth_options) - assert isinstance(token_1, TokenDetails) - - # call authorize again with timestamp set - timestamp = self.ably.time() - with mock.patch('ably.sync.rest.auth.TokenRequest', - wraps=ably.types.tokenrequest.TokenRequest) as tr_mock: - auth_options = dict(self.ably.auth.auth_options.auth_options) - auth_options['auth_headers'] = {'a_headers': 'a_value'} - token_2 = self.ably.auth.authorize( - {'ttl': 60 * 1000, 'client_id': 'new_id', 'timestamp': timestamp}, - auth_options) - assert isinstance(token_2, TokenDetails) - assert token_1 != token_2 - assert tr_mock.call_args[1]['timestamp'] == timestamp - - # call authorize again with no params - with mock.patch('ably.sync.rest.auth.TokenRequest', - wraps=ably.types.tokenrequest.TokenRequest) as tr_mock: - token_4 = self.ably.auth.authorize() - assert isinstance(token_4, TokenDetails) - assert token_2 != token_4 - assert tr_mock.call_args[1]['timestamp'] != timestamp - - def test_client_id_precedence(self): - client_id = uuid.uuid4().hex - overridden_client_id = uuid.uuid4().hex - ably = TestApp.get_ably_rest( - use_binary_protocol=self.use_binary_protocol, - client_id=client_id, - default_token_params={'client_id': overridden_client_id}) - token = ably.auth.authorize() - assert token.client_id == client_id - assert ably.auth.client_id == client_id - - channel = ably.channels[ - self.get_channel_name('test_client_id_precedence')] - channel.publish('test', 'data') - history = channel.history() - assert history.items[0].client_id == client_id - ably.close() - - -class TestRequestToken(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - - def per_protocol_setup(self, use_binary_protocol): - self.use_binary_protocol = use_binary_protocol - - def test_with_key(self): - ably = TestApp.get_ably_rest(use_binary_protocol=self.use_binary_protocol) - - token_details = ably.auth.request_token() - assert isinstance(token_details, TokenDetails) - ably.close() - - ably = TestApp.get_ably_rest(key=None, token_details=token_details, - use_binary_protocol=self.use_binary_protocol) - channel = self.get_channel_name('test_request_token_with_key') - - ably.channels[channel].publish('event', 'foo') - - history = ably.channels[channel].history() - assert history.items[0].data == 'foo' - ably.close() - - @dont_vary_protocol - @respx.mock - def test_with_auth_url_headers_and_params_http_post(self): # noqa: N802 - url = 'http://www.example.com' - headers = {'foo': 'bar'} - ably = TestApp.get_ably_rest(key=None, auth_url=url) - - auth_params = {'foo': 'auth', 'spam': 'eggs'} - token_params = {'foo': 'token'} - auth_route = respx.post(url) - - def call_back(request): - assert request.headers['content-type'] == 'application/x-www-form-urlencoded' - assert headers['foo'] == request.headers['foo'] - - # TokenParams has precedence - assert parse_qs(request.content.decode('utf-8')) == {'foo': ['token'], 'spam': ['eggs']} - return Response( - status_code=200, - content="token_string", - headers={ - "Content-Type": "text/plain", - } - ) - - auth_route.side_effect = call_back - token_details = ably.auth.request_token( - token_params=token_params, auth_url=url, auth_headers=headers, - auth_method='POST', auth_params=auth_params) - - assert 1 == auth_route.called - assert isinstance(token_details, TokenDetails) - assert 'token_string' == token_details.token - ably.close() - - @dont_vary_protocol - @respx.mock - def test_with_auth_url_headers_and_params_http_get(self): # noqa: N802 - url = 'http://www.example.com' - headers = {'foo': 'bar'} - ably = TestApp.get_ably_rest( - key=None, auth_url=url, - auth_headers={'this': 'will_not_be_used'}, - auth_params={'this': 'will_not_be_used'}) - - auth_params = {'foo': 'auth', 'spam': 'eggs'} - token_params = {'foo': 'token'} - auth_route = respx.get(url, params={'foo': ['token'], 'spam': ['eggs']}) - - def call_back(request): - assert request.headers['foo'] == 'bar' - assert 'this' not in request.headers - assert not request.content - - return Response( - status_code=200, - json={'issued': 1, 'token': 'another_token_string'} - ) - auth_route.side_effect = call_back - token_details = ably.auth.request_token( - token_params=token_params, auth_url=url, auth_headers=headers, - auth_params=auth_params) - assert 'another_token_string' == token_details.token - ably.close() - - @dont_vary_protocol - def test_with_callback(self): - called_token_params = {'ttl': '3600000'} - - def callback(token_params): - assert token_params == called_token_params - return 'token_string' - - ably = TestApp.get_ably_rest(key=None, auth_callback=callback) - - token_details = ably.auth.request_token( - token_params=called_token_params, auth_callback=callback) - assert isinstance(token_details, TokenDetails) - assert 'token_string' == token_details.token - - def callback(token_params): - assert token_params == called_token_params - return TokenDetails(token='another_token_string') - - token_details = ably.auth.request_token( - token_params=called_token_params, auth_callback=callback) - assert 'another_token_string' == token_details.token - ably.close() - - @dont_vary_protocol - @respx.mock - def test_when_auth_url_has_query_string(self): - url = 'http://www.example.com?with=query' - headers = {'foo': 'bar'} - ably = TestApp.get_ably_rest(key=None, auth_url=url) - auth_route = respx.get('http://www.example.com', params={'with': 'query', 'spam': 'eggs'}).mock( - return_value=Response(status_code=200, content='token_string', headers={"Content-Type": "text/plain"})) - ably.auth.request_token(auth_url=url, - auth_headers=headers, - auth_params={'spam': 'eggs'}) - assert auth_route.called - ably.close() - - @dont_vary_protocol - def test_client_id_null_for_anonymous_auth(self): - ably = TestApp.get_ably_rest( - key=None, - key_name=self.test_vars["keys"][0]["key_name"], - key_secret=self.test_vars["keys"][0]["key_secret"]) - token = ably.auth.authorize() - - assert isinstance(token, TokenDetails) - assert token.client_id is None - assert ably.auth.client_id is None - ably.close() - - @dont_vary_protocol - def test_client_id_null_until_auth(self): - client_id = uuid.uuid4().hex - token_ably = TestApp.get_ably_rest( - default_token_params={'client_id': client_id}) - # before auth, client_id is None - assert token_ably.auth.client_id is None - - token = token_ably.auth.authorize() - assert isinstance(token, TokenDetails) - - # after auth, client_id is defined - assert token.client_id == client_id - assert token_ably.auth.client_id == client_id - token_ably.close() - - -class TestRenewToken(BaseAsyncTestCase): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - self.host = 'fake-host.ably.io' - self.ably = TestApp.get_ably_rest(use_binary_protocol=False, rest_host=self.host) - # with headers - self.publish_attempts = 0 - self.channel = uuid.uuid4().hex - tokens = ['a_token', 'another_token'] - headers = {'Content-Type': 'application/json'} - self.mocked_api = respx.mock(base_url='https://{}'.format(self.host)) - self.request_token_route = self.mocked_api.post( - "/keys/{}/requestToken".format(self.test_vars["keys"][0]['key_name']), - name="request_token_route") - self.request_token_route.return_value = Response( - status_code=200, - headers=headers, - json={ - 'token': tokens[self.request_token_route.call_count - 1], - 'expires': (time.time() + 60) * 1000 - }, - ) - - def call_back(request): - self.publish_attempts += 1 - if self.publish_attempts in [1, 3]: - return Response( - status_code=201, - headers=headers, - json=[], - ) - return Response( - status_code=401, - headers=headers, - json={ - 'error': {'message': 'Authentication failure', 'statusCode': 401, 'code': 40140} - }, - ) - - self.publish_attempt_route = self.mocked_api.post("/channels/{}/messages".format(self.channel), - name="publish_attempt_route") - self.publish_attempt_route.side_effect = call_back - self.mocked_api.start() - - def tearDown(self): - # We need to have quiet here in order to do not have check if all endpoints were called - self.mocked_api.stop(quiet=True) - self.mocked_api.reset() - self.ably.close() - - # RSA4b - def test_when_renewable(self): - self.ably.auth.authorize() - self.ably.channels[self.channel].publish('evt', 'msg') - assert self.mocked_api["request_token_route"].call_count == 1 - assert self.publish_attempts == 1 - - # Triggers an authentication 401 failure which should automatically request a new token - self.ably.channels[self.channel].publish('evt', 'msg') - assert self.mocked_api["request_token_route"].call_count == 2 - assert self.publish_attempts == 3 - - # RSA4a - def test_when_not_renewable(self): - self.ably.close() - - self.ably = TestApp.get_ably_rest( - key=None, - rest_host=self.host, - token='token ID cannot be used to create a new token', - use_binary_protocol=False) - self.ably.channels[self.channel].publish('evt', 'msg') - assert self.publish_attempts == 1 - - publish = self.ably.channels[self.channel].publish - - match = "Need a new token but auth_options does not include a way to request one" - with pytest.raises(AblyAuthException, match=match): - publish('evt', 'msg') - - assert not self.mocked_api["request_token_route"].called - - # RSA4a - def test_when_not_renewable_with_token_details(self): - token_details = TokenDetails(token='a_dummy_token') - self.ably = TestApp.get_ably_rest( - key=None, - rest_host=self.host, - token_details=token_details, - use_binary_protocol=False) - self.ably.channels[self.channel].publish('evt', 'msg') - assert self.mocked_api["publish_attempt_route"].call_count == 1 - - publish = self.ably.channels[self.channel].publish - - match = "Need a new token but auth_options does not include a way to request one" - with pytest.raises(AblyAuthException, match=match): - publish('evt', 'msg') - - assert not self.mocked_api["request_token_route"].called - - -class TestRenewExpiredToken(BaseAsyncTestCase): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - self.publish_attempts = 0 - self.channel = uuid.uuid4().hex - - self.host = 'fake-host.ably.io' - key = self.test_vars["keys"][0]['key_name'] - headers = {'Content-Type': 'application/json'} - - self.mocked_api = respx.mock(base_url='https://{}'.format(self.host)) - self.request_token_route = self.mocked_api.post("/keys/{}/requestToken".format(key), - name="request_token_route") - self.request_token_route.return_value = Response( - status_code=200, - headers=headers, - json={ - 'token': 'a_token', - 'expires': int(time.time() * 1000), # Always expires - } - ) - self.publish_message_route = self.mocked_api.post("/channels/{}/messages".format(self.channel), - name="publish_message_route") - self.time_route = self.mocked_api.get("/time", name="time_route") - self.time_route.return_value = Response( - status_code=200, - headers=headers, - json=[int(time.time() * 1000)] - ) - - def cb_publish(request): - self.publish_attempts += 1 - if self.publish_fail: - self.publish_fail = False - return Response( - status_code=401, - json={ - 'error': {'message': 'Authentication failure', 'statusCode': 401, 'code': 40140} - } - ) - return Response( - status_code=201, - json='[]' - ) - - self.publish_message_route.side_effect = cb_publish - self.mocked_api.start() - - def tearDown(self): - self.mocked_api.stop(quiet=True) - self.mocked_api.reset() - - # RSA4b1 - def test_query_time_false(self): - ably = TestApp.get_ably_rest(rest_host=self.host) - ably.auth.authorize() - self.publish_fail = True - ably.channels[self.channel].publish('evt', 'msg') - assert self.publish_attempts == 2 - ably.close() - - # RSA4b1 - def test_query_time_true(self): - ably = TestApp.get_ably_rest(query_time=True, rest_host=self.host) - ably.auth.authorize() - self.publish_fail = False - ably.channels[self.channel].publish('evt', 'msg') - assert self.publish_attempts == 1 - ably.close() diff --git a/test/ably/sync/rest/sync_restcapability_test.py b/test/ably/sync/rest/sync_restcapability_test.py deleted file mode 100644 index 224c5d66..00000000 --- a/test/ably/sync/rest/sync_restcapability_test.py +++ /dev/null @@ -1,242 +0,0 @@ -import pytest - -from ably.sync.types.capability import Capability -from ably.sync.util.exceptions import AblyException - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - - -class TestRestCapability(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - self.ably = TestApp.get_ably_rest() - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - - def test_blanket_intersection_with_key(self): - key = self.test_vars['keys'][1] - token_details = self.ably.auth.request_token(key_name=key['key_name'], key_secret=key['key_secret']) - expected_capability = Capability(key["capability"]) - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability." - - def test_equal_intersection_with_key(self): - key = self.test_vars['keys'][1] - - token_details = self.ably.auth.request_token( - key_name=key['key_name'], - key_secret=key['key_secret'], - token_params={'capability': key['capability']}) - - expected_capability = Capability(key["capability"]) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - @dont_vary_protocol - def test_empty_ops_intersection(self): - key = self.test_vars['keys'][1] - with pytest.raises(AblyException): - self.ably.auth.request_token( - key_name=key['key_name'], - key_secret=key['key_secret'], - token_params={'capability': {'testchannel': ['subscribe']}}) - - @dont_vary_protocol - def test_empty_paths_intersection(self): - key = self.test_vars['keys'][1] - with pytest.raises(AblyException): - self.ably.auth.request_token( - key_name=key['key_name'], - key_secret=key['key_secret'], - token_params={'capability': {"testchannelx": ["publish"]}}) - - def test_non_empty_ops_intersection(self): - key = self.test_vars['keys'][4] - - token_params = {"capability": { - "channel2": ["presence", "subscribe"] - }} - kwargs = { - "key_name": key["key_name"], - "key_secret": key["key_secret"], - } - - expected_capability = Capability({ - "channel2": ["subscribe"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - def test_non_empty_paths_intersection(self): - key = self.test_vars['keys'][4] - token_params = { - "capability": { - "channel2": ["presence", "subscribe"], - "channelx": ["presence", "subscribe"], - } - } - kwargs = { - "key_name": key["key_name"], - - "key_secret": key["key_secret"] - } - - expected_capability = Capability({ - "channel2": ["subscribe"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - def test_wildcard_ops_intersection(self): - key = self.test_vars['keys'][4] - - token_params = { - "capability": { - "channel2": ["*"], - }, - } - kwargs = { - "key_name": key["key_name"], - "key_secret": key["key_secret"], - } - - expected_capability = Capability({ - "channel2": ["subscribe", "publish"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - def test_wildcard_ops_intersection_2(self): - key = self.test_vars['keys'][4] - - token_params = { - "capability": { - "channel6": ["publish", "subscribe"], - }, - } - kwargs = { - "key_name": key["key_name"], - "key_secret": key["key_secret"], - } - - expected_capability = Capability({ - "channel6": ["subscribe", "publish"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - def test_wildcard_resources_intersection(self): - key = self.test_vars['keys'][2] - - token_params = { - "capability": { - "cansubscribe": ["subscribe"], - }, - } - kwargs = { - "key_name": key["key_name"], - "key_secret": key["key_secret"], - } - - expected_capability = Capability({ - "cansubscribe": ["subscribe"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - def test_wildcard_resources_intersection_2(self): - key = self.test_vars['keys'][2] - - token_params = { - "capability": { - "cansubscribe:check": ["subscribe"], - }, - } - kwargs = { - "key_name": key["key_name"], - "key_secret": key["key_secret"], - } - - expected_capability = Capability({ - "cansubscribe:check": ["subscribe"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - def test_wildcard_resources_intersection_3(self): - key = self.test_vars['keys'][2] - - token_params = { - "capability": { - "cansubscribe:*": ["subscribe"], - }, - } - kwargs = { - "key_name": key["key_name"], - "key_secret": key["key_secret"], - - } - - expected_capability = Capability({ - "cansubscribe:*": ["subscribe"] - }) - - token_details = self.ably.auth.request_token(token_params, **kwargs) - - assert token_details.token is not None, "Expected token" - assert expected_capability == token_details.capability, "Unexpected capability" - - @dont_vary_protocol - def test_invalid_capabilities(self): - with pytest.raises(AblyException) as excinfo: - self.ably.auth.request_token( - token_params={'capability': {"channel0": ["publish_"]}}) - - the_exception = excinfo.value - assert 400 == the_exception.status_code - assert 40000 == the_exception.code - - @dont_vary_protocol - def test_invalid_capabilities_2(self): - with pytest.raises(AblyException) as excinfo: - self.ably.auth.request_token( - token_params={'capability': {"channel0": ["*", "publish"]}}) - - the_exception = excinfo.value - assert 400 == the_exception.status_code - assert 40000 == the_exception.code - - @dont_vary_protocol - def test_invalid_capabilities_3(self): - with pytest.raises(AblyException) as excinfo: - self.ably.auth.request_token( - token_params={'capability': {"channel0": []}}) - - the_exception = excinfo.value - assert 400 == the_exception.status_code - assert 40000 == the_exception.code diff --git a/test/ably/sync/rest/sync_restchannelhistory_test.py b/test/ably/sync/rest/sync_restchannelhistory_test.py deleted file mode 100644 index 2263aeaa..00000000 --- a/test/ably/sync/rest/sync_restchannelhistory_test.py +++ /dev/null @@ -1,332 +0,0 @@ -import logging -import pytest -import respx - -from ably.sync import AblyException -from ably.sync.http.paginatedresult import PaginatedResultSync - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - -log = logging.getLogger(__name__) - - -class TestRestChannelHistory(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest(fallback_hosts=[]) - self.test_vars = TestApp.get_test_vars() - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - - def test_channel_history_types(self): - history0 = self.get_channel('persisted:channelhistory_types') - - history0.publish('history0', 'This is a string message payload') - history0.publish('history1', b'This is a byte[] message payload') - history0.publish('history2', {'test': 'This is a JSONObject message payload'}) - history0.publish('history3', ['This is a JSONArray message payload']) - - history = history0.history() - assert isinstance(history, PaginatedResultSync) - messages = history.items - assert messages is not None, "Expected non-None messages" - assert 4 == len(messages), "Expected 4 messages" - - message_contents = {m.name: m for m in messages} - assert "This is a string message payload" == message_contents["history0"].data, \ - "Expect history0 to be expected String)" - assert b"This is a byte[] message payload" == message_contents["history1"].data, \ - "Expect history1 to be expected byte[]" - assert {"test": "This is a JSONObject message payload"} == message_contents["history2"].data, \ - "Expect history2 to be expected JSONObject" - assert ["This is a JSONArray message payload"] == message_contents["history3"].data, \ - "Expect history3 to be expected JSONObject" - - expected_message_history = [ - message_contents['history3'], - message_contents['history2'], - message_contents['history1'], - message_contents['history0'], - ] - assert expected_message_history == messages, "Expect messages in reverse order" - - def test_channel_history_multi_50_forwards(self): - history0 = self.get_channel('persisted:channelhistory_multi_50_f') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='forwards') - assert history is not None - messages = history.items - assert len(messages) == 50, "Expected 50 messages" - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(50)] - assert messages == expected_messages, 'Expect messages in forward order' - - def test_channel_history_multi_50_backwards(self): - history0 = self.get_channel('persisted:channelhistory_multi_50_b') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='backwards') - assert history is not None - messages = history.items - assert 50 == len(messages), "Expected 50 messages" - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(49, -1, -1)] - assert expected_messages == messages, 'Expect messages in reverse order' - - def history_mock_url(self, channel_name): - kwargs = { - 'scheme': 'https' if self.test_vars['tls'] else 'http', - 'host': self.test_vars['host'], - 'channel_name': channel_name - } - port = self.test_vars['tls_port'] if self.test_vars.get('tls') else kwargs['port'] - if port == 80: - kwargs['port_sufix'] = '' - else: - kwargs['port_sufix'] = ':' + str(port) - url = '{scheme}://{host}{port_sufix}/channels/{channel_name}/messages' - return url.format(**kwargs) - - @respx.mock - @dont_vary_protocol - def test_channel_history_default_limit(self): - self.per_protocol_setup(True) - channel = self.ably.channels['persisted:channelhistory_limit'] - url = self.history_mock_url('persisted:channelhistory_limit') - self.respx_add_empty_msg_pack(url) - channel.history() - assert 'limit' not in respx.calls[0].request.url.params.keys() - - @respx.mock - @dont_vary_protocol - def test_channel_history_with_limits(self): - self.per_protocol_setup(True) - channel = self.ably.channels['persisted:channelhistory_limit'] - url = self.history_mock_url('persisted:channelhistory_limit') - self.respx_add_empty_msg_pack(url) - - channel.history(limit=500) - assert '500' in respx.calls[0].request.url.params.get('limit') - - channel.history(limit=1000) - assert '1000' in respx.calls[1].request.url.params.get('limit') - - @dont_vary_protocol - def test_channel_history_max_limit_is_1000(self): - channel = self.ably.channels['persisted:channelhistory_limit'] - with pytest.raises(AblyException): - channel.history(limit=1001) - - def test_channel_history_limit_forwards(self): - history0 = self.get_channel('persisted:channelhistory_limit_f') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='forwards', limit=25) - assert history is not None - messages = history.items - assert len(messages) == 25, "Expected 25 messages" - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(25)] - assert messages == expected_messages, 'Expect messages in forward order' - - def test_channel_history_limit_backwards(self): - history0 = self.get_channel('persisted:channelhistory_limit_b') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='backwards', limit=25) - assert history is not None - messages = history.items - assert len(messages) == 25, "Expected 25 messages" - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(49, 24, -1)] - assert messages == expected_messages, 'Expect messages in forward order' - - def test_channel_history_time_forwards(self): - history0 = self.get_channel('persisted:channelhistory_time_f') - - for i in range(20): - history0.publish('history%d' % i, str(i)) - - interval_start = self.ably.time() - - for i in range(20, 40): - history0.publish('history%d' % i, str(i)) - - interval_end = self.ably.time() - - for i in range(40, 60): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='forwards', start=interval_start, - end=interval_end) - - messages = history.items - assert 20 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(20, 40)] - assert expected_messages == messages, 'Expect messages in forward order' - - def test_channel_history_time_backwards(self): - history0 = self.get_channel('persisted:channelhistory_time_b') - - for i in range(20): - history0.publish('history%d' % i, str(i)) - - interval_start = self.ably.time() - - for i in range(20, 40): - history0.publish('history%d' % i, str(i)) - - interval_end = self.ably.time() - - for i in range(40, 60): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='backwards', start=interval_start, - end=interval_end) - - messages = history.items - assert 20 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(39, 19, -1)] - assert expected_messages, messages == 'Expect messages in reverse order' - - def test_channel_history_paginate_forwards(self): - history0 = self.get_channel('persisted:channelhistory_paginate_f') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='forwards', limit=10) - messages = history.items - - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(0, 10)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.next() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(10, 20)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.next() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(20, 30)] - assert expected_messages == messages, 'Expected 10 messages' - - def test_channel_history_paginate_backwards(self): - history0 = self.get_channel('persisted:channelhistory_paginate_b') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='backwards', limit=10) - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(49, 39, -1)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.next() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(39, 29, -1)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.next() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(29, 19, -1)] - assert expected_messages == messages, 'Expected 10 messages' - - def test_channel_history_paginate_forwards_first(self): - history0 = self.get_channel('persisted:channelhistory_paginate_first_f') - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='forwards', limit=10) - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(0, 10)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.next() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(10, 20)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.first() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(0, 10)] - assert expected_messages == messages, 'Expected 10 messages' - - def test_channel_history_paginate_backwards_rel_first(self): - history0 = self.get_channel('persisted:channelhistory_paginate_first_b') - - for i in range(50): - history0.publish('history%d' % i, str(i)) - - history = history0.history(direction='backwards', limit=10) - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(49, 39, -1)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.next() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(39, 29, -1)] - assert expected_messages == messages, 'Expected 10 messages' - - history = history.first() - messages = history.items - assert 10 == len(messages) - - message_contents = {m.name: m for m in messages} - expected_messages = [message_contents['history%d' % i] for i in range(49, 39, -1)] - assert expected_messages == messages, 'Expected 10 messages' diff --git a/test/ably/sync/rest/sync_restchannelpublish_test.py b/test/ably/sync/rest/sync_restchannelpublish_test.py deleted file mode 100644 index a44ab265..00000000 --- a/test/ably/sync/rest/sync_restchannelpublish_test.py +++ /dev/null @@ -1,568 +0,0 @@ -import base64 -import binascii -import json -import logging -import os -import uuid - -import httpx -import mock -import msgpack -import pytest - -from ably.sync import api_version -from ably.sync import AblyException, IncompatibleClientIdException -from ably.sync.rest.auth import AuthSync -from ably.sync.types.message import Message -from ably.sync.types.tokendetails import TokenDetails -from ably.sync.util import case -from test.ably.sync import utils - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - -log = logging.getLogger(__name__) - - -# Ignore library warning regarding client_id -@pytest.mark.filterwarnings('ignore::DeprecationWarning') -class TestRestChannelPublish(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - self.ably = TestApp.get_ably_rest() - self.client_id = uuid.uuid4().hex - self.ably_with_client_id = TestApp.get_ably_rest(client_id=self.client_id, use_token_auth=True) - - def tearDown(self): - self.ably.close() - self.ably_with_client_id.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.ably_with_client_id.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - def test_publish_various_datatypes_text(self): - publish0 = self.ably.channels[ - self.get_channel_name('persisted:publish0')] - - publish0.publish("publish0", "This is a string message payload") - publish0.publish("publish1", b"This is a byte[] message payload") - publish0.publish("publish2", {"test": "This is a JSONObject message payload"}) - publish0.publish("publish3", ["This is a JSONArray message payload"]) - - # Get the history for this channel - history = publish0.history() - messages = history.items - assert messages is not None, "Expected non-None messages" - assert len(messages) == 4, "Expected 4 messages" - - message_contents = dict((m.name, m.data) for m in messages) - log.debug("message_contents: %s" % str(message_contents)) - - assert message_contents["publish0"] == "This is a string message payload", \ - "Expect publish0 to be expected String)" - - assert message_contents["publish1"] == b"This is a byte[] message payload", \ - "Expect publish1 to be expected byte[]. Actual: %s" % str(message_contents['publish1']) - - assert message_contents["publish2"] == {"test": "This is a JSONObject message payload"}, \ - "Expect publish2 to be expected JSONObject" - - assert message_contents["publish3"] == ["This is a JSONArray message payload"], \ - "Expect publish3 to be expected JSONObject" - - @dont_vary_protocol - def test_unsupported_payload_must_raise_exception(self): - channel = self.ably.channels["persisted:publish0"] - for data in [1, 1.1, True]: - with pytest.raises(AblyException): - channel.publish('event', data) - - def test_publish_message_list(self): - channel = self.ably.channels[ - self.get_channel_name('persisted:message_list_channel')] - - expected_messages = [Message("name-{}".format(i), str(i)) for i in range(3)] - - channel.publish(messages=expected_messages) - - # Get the history for this channel - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == len(expected_messages), "Expected 3 messages" - - for m, expected_m in zip(messages, reversed(expected_messages)): - assert m.name == expected_m.name - assert m.data == expected_m.data - - def test_message_list_generate_one_request(self): - channel = self.ably.channels[ - self.get_channel_name('persisted:message_list_channel_one_request')] - - expected_messages = [Message("name-{}".format(i), str(i)) for i in range(3)] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish(messages=expected_messages) - assert post_mock.call_count == 1 - - if self.use_binary_protocol: - messages = msgpack.unpackb(post_mock.call_args[1]['body']) - else: - messages = json.loads(post_mock.call_args[1]['body']) - - for i, message in enumerate(messages): - assert message['name'] == 'name-' + str(i) - assert message['data'] == str(i) - - def test_publish_error(self): - ably = TestApp.get_ably_rest(use_binary_protocol=self.use_binary_protocol) - ably.auth.authorize( - token_params={'capability': {"only_subscribe": ["subscribe"]}}) - - with pytest.raises(AblyException) as excinfo: - ably.channels["only_subscribe"].publish() - - assert 401 == excinfo.value.status_code - assert 40160 == excinfo.value.code - ably.close() - - def test_publish_message_null_name(self): - channel = self.ably.channels[ - self.get_channel_name('persisted:message_null_name_channel')] - - data = "String message" - channel.publish(name=None, data=data) - - # Get the history for this channel - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 1, "Expected 1 message" - assert messages[0].name is None - assert messages[0].data == data - - def test_publish_message_null_data(self): - channel = self.ably.channels[ - self.get_channel_name('persisted:message_null_data_channel')] - - name = "Test name" - channel.publish(name=name, data=None) - - # Get the history for this channel - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 1, "Expected 1 message" - - assert messages[0].name == name - assert messages[0].data is None - - def test_publish_message_null_name_and_data(self): - channel = self.ably.channels[ - self.get_channel_name('persisted:null_name_and_data_channel')] - - channel.publish(name=None, data=None) - channel.publish() - - # Get the history for this channel - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 2, "Expected 2 messages" - - for m in messages: - assert m.name is None - assert m.data is None - - def test_publish_message_null_name_and_data_keys_arent_sent(self): - channel = self.ably.channels[ - self.get_channel_name('persisted:null_name_and_data_keys_arent_sent_channel')] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish(name=None, data=None) - - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 1, "Expected 1 message" - - assert post_mock.call_count == 1 - - if self.use_binary_protocol: - posted_body = msgpack.unpackb(post_mock.call_args[1]['body']) - else: - posted_body = json.loads(post_mock.call_args[1]['body']) - - assert 'name' not in posted_body - assert 'data' not in posted_body - - def test_message_attr(self): - publish0 = self.ably.channels[ - self.get_channel_name('persisted:publish_message_attr')] - - messages = [Message('publish', - {"test": "This is a JSONObject message payload"}, - client_id='client_id')] - publish0.publish(messages=messages) - - # Get the history for this channel - history = publish0.history() - message = history.items[0] - assert isinstance(message, Message) - assert message.id - assert message.name - assert message.data == {'test': 'This is a JSONObject message payload'} - assert message.encoding == '' - assert message.client_id == 'client_id' - assert isinstance(message.timestamp, int) - - def test_token_is_bound_to_options_client_id_after_publish(self): - # null before publish - assert self.ably_with_client_id.auth.token_details is None - - # created after message publish and will have client_id - channel = self.ably_with_client_id.channels[ - self.get_channel_name('persisted:restricted_to_client_id')] - channel.publish(name='publish', data='test') - - # defined after publish - assert isinstance(self.ably_with_client_id.auth.token_details, TokenDetails) - assert self.ably_with_client_id.auth.token_details.client_id == self.client_id - assert self.ably_with_client_id.auth.auth_mechanism == AuthSync.Method.TOKEN - history = channel.history() - assert history.items[0].client_id == self.client_id - - def test_publish_message_without_client_id_on_identified_client(self): - channel = self.ably_with_client_id.channels[ - self.get_channel_name('persisted:no_client_id_identified_client')] - - with mock.patch('ably.sync.rest.rest.HttpSync.post', - wraps=channel.ably.http.post) as post_mock: - channel.publish(name='publish', data='test') - - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 1, "Expected 1 message" - - assert post_mock.call_count == 2 - - if self.use_binary_protocol: - posted_body = msgpack.unpackb( - post_mock.mock_calls[0][2]['body']) - else: - posted_body = json.loads( - post_mock.mock_calls[0][2]['body']) - - assert 'client_id' not in posted_body - - # Get the history for this channel - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 1, "Expected 1 message" - - assert messages[0].client_id == self.ably_with_client_id.client_id - - def test_publish_message_with_client_id_on_identified_client(self): - # works if same - channel = self.ably_with_client_id.channels[ - self.get_channel_name('persisted:with_client_id_identified_client')] - message = Message(name='publish', data='test', client_id=self.ably_with_client_id.client_id) - channel.publish(message) - - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 1, "Expected 1 message" - - assert messages[0].client_id == self.ably_with_client_id.client_id - - message = Message(name='publish', data='test', client_id='invalid') - # fails if different - with pytest.raises(IncompatibleClientIdException): - channel.publish(message) - - def test_publish_message_with_wrong_client_id_on_implicit_identified_client(self): - new_token = self.ably.auth.authorize(token_params={'client_id': uuid.uuid4().hex}) - new_ably = TestApp.get_ably_rest(key=None, - token=new_token.token, - use_binary_protocol=self.use_binary_protocol) - - channel = new_ably.channels[ - self.get_channel_name('persisted:wrong_client_id_implicit_client')] - - message = Message(name='publish', data='test', client_id='invalid') - with pytest.raises(AblyException) as excinfo: - channel.publish(message) - - assert 400 == excinfo.value.status_code - assert 40012 == excinfo.value.code - new_ably.close() - - # RSA15b - def test_wildcard_client_id_can_publish_as_others(self): - wildcard_token_details = self.ably.auth.request_token({'client_id': '*'}) - wildcard_ably = TestApp.get_ably_rest( - key=None, - token_details=wildcard_token_details, - use_binary_protocol=self.use_binary_protocol) - - assert wildcard_ably.auth.client_id == '*' - channel = wildcard_ably.channels[ - self.get_channel_name('persisted:wildcard_client_id')] - channel.publish(name='publish1', data='no client_id') - some_client_id = uuid.uuid4().hex - message = Message(name='publish2', data='some client_id', client_id=some_client_id) - channel.publish(message) - - history = channel.history() - messages = history.items - - assert messages is not None, "Expected non-None messages" - assert len(messages) == 2, "Expected 2 messages" - - assert messages[0].client_id == some_client_id - assert messages[1].client_id is None - - wildcard_ably.close() - - # TM2h - @dont_vary_protocol - def test_invalid_connection_key(self): - channel = self.ably.channels["persisted:invalid_connection_key"] - message = Message(data='payload', connection_key='should.be.wrong') - with pytest.raises(AblyException) as excinfo: - channel.publish(messages=[message]) - - assert 400 == excinfo.value.status_code - assert 40006 == excinfo.value.code - - # TM2i, RSL6a2, RSL1h - def test_publish_extras(self): - channel = self.ably.channels[ - self.get_channel_name('canpublish:extras_channel')] - extras = { - 'push': { - 'notification': {"title": "Testing"}, - } - } - message = Message(name='test-name', data='test-data', extras=extras) - channel.publish(message) - - # Get the history for this channel - history = channel.history() - message = history.items[0] - assert message.name == 'test-name' - assert message.data == 'test-data' - assert message.extras == extras - - # RSL6a1 - def test_interoperability(self): - name = self.get_channel_name('persisted:interoperability_channel') - channel = self.ably.channels[name] - - url = 'https://%s/channels/%s/messages' % (self.test_vars["host"], name) - key = self.test_vars['keys'][0] - auth = (key['key_name'], key['key_secret']) - - type_mapping = { - 'string': str, - 'jsonObject': dict, - 'jsonArray': list, - 'binary': bytearray, - } - - path = os.path.join(utils.get_submodule_dir(__file__), 'test-resources', 'messages-encoding.json') - with open(path) as f: - data = json.load(f) - for input_msg in data['messages']: - data = input_msg['data'] - encoding = input_msg['encoding'] - expected_type = input_msg['expectedType'] - if expected_type == 'binary': - expected_value = input_msg.get('expectedHexValue') - expected_value = expected_value.encode('ascii') - expected_value = binascii.a2b_hex(expected_value) - else: - expected_value = input_msg.get('expectedValue') - - # 1) - channel.publish(data=expected_value) - with httpx.Client(http2=True) as client: - r = client.get(url, auth=auth) - item = r.json()[0] - assert item.get('encoding') == encoding - if encoding == 'json': - assert json.loads(item['data']) == json.loads(data) - else: - assert item['data'] == data - - # 2) - channel.publish(messages=[Message(data=data, encoding=encoding)]) - history = channel.history() - message = history.items[0] - assert message.data == expected_value - assert type(message.data) == type_mapping[expected_type] - - # https://github.com/ably/ably-python/issues/130 - def test_publish_slash(self): - channel = self.ably.channels.get(self.get_channel_name('persisted:widgets/')) - name, data = 'Name', 'Data' - channel.publish(name, data) - history = channel.history() - assert len(history.items) == 1 - assert history.items[0].name == name - assert history.items[0].data == data - - # RSL1l - @dont_vary_protocol - def test_publish_params(self): - channel = self.ably.channels.get(self.get_channel_name()) - - message = Message('name', 'data') - with pytest.raises(AblyException) as excinfo: - channel.publish(message, {'_forceNack': True}) - - assert 400 == excinfo.value.status_code - assert 40099 == excinfo.value.code - - -class TestRestChannelPublishIdempotent(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - self.ably_idempotent = TestApp.get_ably_rest(idempotent_rest_publishing=True) - - def tearDown(self): - self.ably.close() - self.ably_idempotent.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - # TO3n - @dont_vary_protocol - def test_idempotent_rest_publishing(self): - # Test default value - if api_version < '1.2': - assert self.ably.options.idempotent_rest_publishing is False - else: - assert self.ably.options.idempotent_rest_publishing is True - - # Test setting value explicitly - ably = TestApp.get_ably_rest(idempotent_rest_publishing=True) - assert ably.options.idempotent_rest_publishing is True - ably.close() - - ably = TestApp.get_ably_rest(idempotent_rest_publishing=False) - assert ably.options.idempotent_rest_publishing is False - ably.close() - - # RSL1j - @dont_vary_protocol - def test_message_serialization(self): - channel = self.get_channel() - - data = { - 'name': 'name', - 'data': 'data', - 'client_id': 'client_id', - 'extras': {}, - 'id': 'foobar', - } - message = Message(**data) - request_body = channel._ChannelSync__publish_request_body(messages=[message]) - input_keys = set(case.snake_to_camel(x) for x in data.keys()) - assert input_keys - set(request_body) == set() - - # RSL1k1 - @dont_vary_protocol - def test_idempotent_library_generated(self): - channel = self.ably_idempotent.channels[self.get_channel_name()] - - message = Message('name', 'data') - request_body = channel._ChannelSync__publish_request_body(messages=[message]) - base_id, serial = request_body['id'].split(':') - assert len(base64.b64decode(base_id)) >= 9 - assert serial == '0' - - # RSL1k2 - @dont_vary_protocol - def test_idempotent_client_supplied(self): - channel = self.ably_idempotent.channels[self.get_channel_name()] - - message = Message('name', 'data', id='foobar') - request_body = channel._ChannelSync__publish_request_body(messages=[message]) - assert request_body['id'] == 'foobar' - - # RSL1k3 - @dont_vary_protocol - def test_idempotent_mixed_ids(self): - channel = self.ably_idempotent.channels[self.get_channel_name()] - - messages = [ - Message('name', 'data', id='foobar'), - Message('name', 'data'), - ] - request_body = channel._ChannelSync__publish_request_body(messages=messages) - assert request_body[0]['id'] == 'foobar' - assert 'id' not in request_body[1] - - def get_ably_rest(self, *args, **kwargs): - kwargs['use_binary_protocol'] = self.use_binary_protocol - return TestApp.get_ably_rest(*args, **kwargs) - - # RSL1k4 - def test_idempotent_library_generated_retry(self): - test_vars = TestApp.get_test_vars() - ably = self.get_ably_rest(idempotent_rest_publishing=True, fallback_hosts=[test_vars["host"]] * 3) - channel = ably.channels[self.get_channel_name()] - - state = {'failures': 0} - client = httpx.Client(http2=True) - send = client.send - - def side_effect(*args, **kwargs): - x = send(args[1]) - if state['failures'] < 2: - state['failures'] += 1 - raise Exception('faked exception') - return x - - messages = [Message('name1', 'data1')] - with mock.patch('httpx.Client.send', side_effect=side_effect, autospec=True): - channel.publish(messages=messages) - - assert state['failures'] == 2 - history = channel.history() - assert len(history.items) == 1 - client.close() - ably.close() - - # RSL1k5 - def test_idempotent_client_supplied_publish(self): - ably = self.get_ably_rest(idempotent_rest_publishing=True) - channel = ably.channels[self.get_channel_name()] - - messages = [Message('name1', 'data1', id='foobar')] - channel.publish(messages=messages) - channel.publish(messages=messages) - channel.publish(messages=messages) - history = channel.history() - assert len(history.items) == 1 - ably.close() diff --git a/test/ably/sync/rest/sync_restchannels_test.py b/test/ably/sync/rest/sync_restchannels_test.py deleted file mode 100644 index 88587313..00000000 --- a/test/ably/sync/rest/sync_restchannels_test.py +++ /dev/null @@ -1,91 +0,0 @@ -from collections.abc import Iterable - -import pytest - -from ably.sync import AblyException -from ably.sync.rest.channel import ChannelSync, ChannelsSync, Presence -from ably.sync.util.crypto import generate_random_key - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import BaseAsyncTestCase - - -# makes no request, no need to use different protocols -class TestChannels(BaseAsyncTestCase): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - self.ably = TestApp.get_ably_rest() - - def tearDown(self): - self.ably.close() - - def test_rest_channels_attr(self): - assert hasattr(self.ably, 'channels') - assert isinstance(self.ably.channels, ChannelsSync) - - def test_channels_get_returns_new_or_existing(self): - channel = self.ably.channels.get('new_channel') - assert isinstance(channel, ChannelSync) - channel_same = self.ably.channels.get('new_channel') - assert channel is channel_same - - def test_channels_get_returns_new_with_options(self): - key = generate_random_key() - channel = self.ably.channels.get('new_channel', cipher={'key': key}) - assert isinstance(channel, ChannelSync) - assert channel.cipher.secret_key is key - - def test_channels_get_updates_existing_with_options(self): - key = generate_random_key() - channel = self.ably.channels.get('new_channel', cipher={'key': key}) - assert channel.cipher is not None - - channel_same = self.ably.channels.get('new_channel', cipher=None) - assert channel is channel_same - assert channel.cipher is None - - def test_channels_get_doesnt_updates_existing_with_none_options(self): - key = generate_random_key() - channel = self.ably.channels.get('new_channel', cipher={'key': key}) - assert channel.cipher is not None - - channel_same = self.ably.channels.get('new_channel') - assert channel is channel_same - assert channel.cipher is not None - - def test_channels_in(self): - assert 'new_channel' not in self.ably.channels - self.ably.channels.get('new_channel') - new_channel_2 = self.ably.channels.get('new_channel_2') - assert 'new_channel' in self.ably.channels - assert new_channel_2 in self.ably.channels - - def test_channels_iteration(self): - channel_names = ['channel_{}'.format(i) for i in range(5)] - [self.ably.channels.get(name) for name in channel_names] - - assert isinstance(self.ably.channels, Iterable) - for name, channel in zip(channel_names, self.ably.channels): - assert isinstance(channel, ChannelSync) - assert name == channel.name - - # RSN4a, RSN4b - def test_channels_release(self): - self.ably.channels.get('new_channel') - self.ably.channels.release('new_channel') - self.ably.channels.release('new_channel') - - def test_channel_has_presence(self): - channel = self.ably.channels.get('new_channnel') - assert channel.presence - assert isinstance(channel.presence, Presence) - - def test_without_permissions(self): - key = self.test_vars["keys"][2] - ably = TestApp.get_ably_rest(key=key["key_str"]) - with pytest.raises(AblyException) as excinfo: - ably.channels['test_publish_without_permission'].publish('foo', 'woop') - - assert 'not permitted' in excinfo.value.message - ably.close() diff --git a/test/ably/sync/rest/sync_restchannelstatus_test.py b/test/ably/sync/rest/sync_restchannelstatus_test.py deleted file mode 100644 index 5d281221..00000000 --- a/test/ably/sync/rest/sync_restchannelstatus_test.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, BaseAsyncTestCase - -log = logging.getLogger(__name__) - - -class TestRestChannelStatus(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - def test_channel_status(self): - channel_name = self.get_channel_name('test_channel_status') - channel = self.ably.channels[channel_name] - - channel_status = channel.status() - - assert channel_status is not None, "Expected non-None channel_status" - assert channel_name == channel_status.channel_id, "Expected channel name to match" - assert channel_status.status.is_active is True, "Expected is_active to be True" - assert isinstance(channel_status.status.occupancy.metrics.publishers, int) and\ - channel_status.status.occupancy.metrics.publishers >= 0,\ - "Expected publishers to be a non-negative int" - assert isinstance(channel_status.status.occupancy.metrics.connections, int) and\ - channel_status.status.occupancy.metrics.connections >= 0,\ - "Expected connections to be a non-negative int" - assert isinstance(channel_status.status.occupancy.metrics.subscribers, int) and\ - channel_status.status.occupancy.metrics.subscribers >= 0,\ - "Expected subscribers to be a non-negative int" - assert isinstance(channel_status.status.occupancy.metrics.presence_members, int) and\ - channel_status.status.occupancy.metrics.presence_members >= 0,\ - "Expected presence_members to be a non-negative int" - assert isinstance(channel_status.status.occupancy.metrics.presence_connections, int) and\ - channel_status.status.occupancy.metrics.presence_connections >= 0,\ - "Expected presence_connections to be a non-negative int" - assert isinstance(channel_status.status.occupancy.metrics.presence_subscribers, int) and\ - channel_status.status.occupancy.metrics.presence_subscribers >= 0,\ - "Expected presence_subscribers to be a non-negative int" diff --git a/test/ably/sync/rest/sync_restcrypto_test.py b/test/ably/sync/rest/sync_restcrypto_test.py deleted file mode 100644 index 3dd89bc2..00000000 --- a/test/ably/sync/rest/sync_restcrypto_test.py +++ /dev/null @@ -1,264 +0,0 @@ -# import json -# import os -# import logging -# import base64 -# -# import pytest -# -# from ably import AblyException -# from ably.types.message import Message -# from ably.util.crypto import CipherParams, get_cipher, generate_random_key, get_default_params -# -# from Crypto import Random -# -# from test.ably.testapp import TestApp -# from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase -# -# log = logging.getLogger(__name__) -# -# -# class TestRestCrypto(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): -# -# async def asyncSetUp(self): -# self.test_vars = await TestApp.get_test_vars() -# self.ably = await TestApp.get_ably_rest() -# self.ably2 = await TestApp.get_ably_rest() -# -# async def asyncTearDown(self): -# await self.ably.close() -# await self.ably2.close() -# -# def per_protocol_setup(self, use_binary_protocol): -# # This will be called every test that vary by protocol for each protocol -# self.ably.options.use_binary_protocol = use_binary_protocol -# self.ably2.options.use_binary_protocol = use_binary_protocol -# self.use_binary_protocol = use_binary_protocol -# -# @dont_vary_protocol -# def test_cbc_channel_cipher(self): -# key = ( -# b'\x93\xe3\x5c\xc9\x77\x53\xfd\x1a' -# b'\x79\xb4\xd8\x84\xe7\xdc\xfd\xdf') -# -# iv = ( -# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' -# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0') -# -# log.debug("KEY_LEN: %d" % len(key)) -# log.debug("IV_LEN: %d" % len(iv)) -# cipher = get_cipher({'key': key, 'iv': iv}) -# -# plaintext = b"The quick brown fox" -# expected_ciphertext = ( -# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' -# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0' -# b'\x83\x5c\xcf\xce\x0c\xfd\xbe\x37' -# b'\xb7\x92\x12\x04\x1d\x45\x68\xa4' -# b'\xdf\x7f\x6e\x38\x17\x4a\xff\x50' -# b'\x73\x23\xbb\xca\x16\xb0\xe2\x84') -# -# actual_ciphertext = cipher.encrypt(plaintext) -# -# assert expected_ciphertext == actual_ciphertext -# -# async def test_crypto_publish(self): -# channel_name = self.get_channel_name('persisted:crypto_publish_text') -# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# history = await publish0.history() -# messages = history.items -# assert messages is not None, "Expected non-None messages" -# assert 4 == len(messages), "Expected 4 messages" -# -# message_contents = dict((m.name, m.data) for m in messages) -# log.debug("message_contents: %s" % str(message_contents)) -# -# assert "This is a string message payload" == message_contents["publish3"],\ -# "Expect publish3 to be expected String)" -# -# assert b"This is a byte[] message payload" == message_contents["publish4"],\ -# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) -# -# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ -# "Expect publish5 to be expected JSONObject" -# -# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ -# "Expect publish6 to be expected JSONObject" -# -# async def test_crypto_publish_256(self): -# rndfile = Random.new() -# key = rndfile.read(32) -# channel_name = 'persisted:crypto_publish_text_256' -# channel_name += '_bin' if self.use_binary_protocol else '_text' -# -# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# history = await publish0.history() -# messages = history.items -# assert messages is not None, "Expected non-None messages" -# assert 4 == len(messages), "Expected 4 messages" -# -# message_contents = dict((m.name, m.data) for m in messages) -# log.debug("message_contents: %s" % str(message_contents)) -# -# assert "This is a string message payload" == message_contents["publish3"],\ -# "Expect publish3 to be expected String)" -# -# assert b"This is a byte[] message payload" == message_contents["publish4"],\ -# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) -# -# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ -# "Expect publish5 to be expected JSONObject" -# -# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ -# "Expect publish6 to be expected JSONObject" -# -# async def test_crypto_publish_key_mismatch(self): -# channel_name = self.get_channel_name('persisted:crypto_publish_key_mismatch') -# -# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# with pytest.raises(AblyException) as excinfo: -# await rx_channel.history() -# -# message = excinfo.value.message -# assert 'invalid-padding' == message or "codec can't decode" in message -# -# async def test_crypto_send_unencrypted(self): -# channel_name = self.get_channel_name('persisted:crypto_send_unencrypted') -# publish0 = self.ably.channels[channel_name] -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# history = await rx_channel.history() -# messages = history.items -# assert messages is not None, "Expected non-None messages" -# assert 4 == len(messages), "Expected 4 messages" -# -# message_contents = dict((m.name, m.data) for m in messages) -# log.debug("message_contents: %s" % str(message_contents)) -# -# assert "This is a string message payload" == message_contents["publish3"],\ -# "Expect publish3 to be expected String" -# -# assert b"This is a byte[] message payload" == message_contents["publish4"],\ -# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) -# -# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ -# "Expect publish5 to be expected JSONObject" -# -# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ -# "Expect publish6 to be expected JSONObject" -# -# async def test_crypto_encrypted_unhandled(self): -# channel_name = self.get_channel_name('persisted:crypto_send_encrypted_unhandled') -# key = b'0123456789abcdef' -# data = 'foobar' -# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) -# -# await publish0.publish("publish0", data) -# -# rx_channel = self.ably2.channels[channel_name] -# history = await rx_channel.history() -# message = history.items[0] -# cipher = get_cipher(get_default_params({'key': key})) -# assert cipher.decrypt(message.data).decode() == data -# assert message.encoding == 'utf-8/cipher+aes-128-cbc' -# -# @dont_vary_protocol -# def test_cipher_params(self): -# params = CipherParams(secret_key='0123456789abcdef') -# assert params.algorithm == 'AES' -# assert params.mode == 'CBC' -# assert params.key_length == 128 -# -# params = CipherParams(secret_key='0123456789abcdef' * 2) -# assert params.algorithm == 'AES' -# assert params.mode == 'CBC' -# assert params.key_length == 256 -# -# -# class AbstractTestCryptoWithFixture: -# -# @classmethod -# def setUpClass(cls): -# resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file -# with open(resources_path, 'r') as f: -# cls.fixture = json.loads(f.read()) -# cls.params = { -# 'secret_key': base64.b64decode(cls.fixture['key'].encode('ascii')), -# 'mode': cls.fixture['mode'], -# 'algorithm': cls.fixture['algorithm'], -# 'iv': base64.b64decode(cls.fixture['iv'].encode('ascii')), -# } -# cls.cipher_params = CipherParams(**cls.params) -# cls.cipher = get_cipher(cls.cipher_params) -# cls.items = cls.fixture['items'] -# -# def get_encoded(self, encoded_item): -# if encoded_item.get('encoding') == 'base64': -# return base64.b64decode(encoded_item['data'].encode('ascii')) -# elif encoded_item.get('encoding') == 'json': -# return json.loads(encoded_item['data']) -# return encoded_item['data'] -# -# # TM3 -# def test_decode(self): -# for item in self.items: -# assert item['encoded']['name'] == item['encrypted']['name'] -# message = Message.from_encoded(item['encrypted'], self.cipher) -# assert message.encoding == '' -# expected_data = self.get_encoded(item['encoded']) -# assert expected_data == message.data -# -# # TM3 -# def test_decode_array(self): -# items_encrypted = [item['encrypted'] for item in self.items] -# messages = Message.from_encoded_array(items_encrypted, self.cipher) -# for i, message in enumerate(messages): -# assert message.encoding == '' -# expected_data = self.get_encoded(self.items[i]['encoded']) -# assert expected_data == message.data -# -# def test_encode(self): -# for item in self.items: -# # need to reset iv -# self.cipher_params = CipherParams(**self.params) -# self.cipher = get_cipher(self.cipher_params) -# data = self.get_encoded(item['encoded']) -# expected = item['encrypted'] -# message = Message(item['encoded']['name'], data) -# message.encrypt(self.cipher) -# as_dict = message.as_dict() -# assert as_dict['data'] == expected['data'] -# assert as_dict['encoding'] == expected['encoding'] -# -# -# class TestCryptoWithFixture128(AbstractTestCryptoWithFixture, BaseTestCase): -# fixture_file = 'crypto-data-128.json' -# -# -# class TestCryptoWithFixture256(AbstractTestCryptoWithFixture, BaseTestCase): -# fixture_file = 'crypto-data-256.json' diff --git a/test/ably/sync/rest/sync_resthttp_test.py b/test/ably/sync/rest/sync_resthttp_test.py deleted file mode 100644 index 0c00b55b..00000000 --- a/test/ably/sync/rest/sync_resthttp_test.py +++ /dev/null @@ -1,229 +0,0 @@ -import base64 -import re -import time - -import httpx -import mock -import pytest -from urllib.parse import urljoin - -import respx -from httpx import Response - -from ably.sync import AblyRestSync -from ably.sync.transport.defaults import Defaults -from ably.sync.types.options import Options -from ably.sync.util.exceptions import AblyException -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import BaseAsyncTestCase - - -class TestRestHttp(BaseAsyncTestCase): - def test_max_retry_attempts_and_timeouts_defaults(self): - ably = AblyRestSync(token="foo") - assert 'http_open_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS - assert 'http_request_timeout' in ably.http.CONNECTION_RETRY_DEFAULTS - - with mock.patch('httpx.Client.send', side_effect=httpx.RequestError('')) as send_mock: - with pytest.raises(httpx.RequestError): - ably.http.make_request('GET', '/', version=Defaults.protocol_version, skip_auth=True) - - assert send_mock.call_count == Defaults.http_max_retry_count - assert send_mock.call_args == mock.call(mock.ANY) - ably.close() - - def test_cumulative_timeout(self): - ably = AblyRestSync(token="foo") - assert 'http_max_retry_duration' in ably.http.CONNECTION_RETRY_DEFAULTS - - ably.options.http_max_retry_duration = 0.5 - - def sleep_and_raise(*args, **kwargs): - time.sleep(0.51) - raise httpx.TimeoutException('timeout') - - with mock.patch('httpx.Client.send', side_effect=sleep_and_raise) as send_mock: - with pytest.raises(httpx.TimeoutException): - ably.http.make_request('GET', '/', skip_auth=True) - - assert send_mock.call_count == 1 - ably.close() - - def test_host_fallback(self): - ably = AblyRestSync(token="foo") - - def make_url(host): - base_url = "%s://%s:%d" % (ably.http.preferred_scheme, - host, - ably.http.preferred_port) - return urljoin(base_url, '/') - - with mock.patch('httpx.Request', wraps=httpx.Request) as request_mock: - with mock.patch('httpx.Client.send', side_effect=httpx.RequestError('')) as send_mock: - with pytest.raises(httpx.RequestError): - ably.http.make_request('GET', '/', skip_auth=True) - - assert send_mock.call_count == Defaults.http_max_retry_count - - expected_urls_set = { - make_url(host) - for host in Options(http_max_retry_count=10).get_rest_hosts() - } - for ((_, url), _) in request_mock.call_args_list: - assert url in expected_urls_set - expected_urls_set.remove(url) - - expected_hosts_set = set(Options(http_max_retry_count=10).get_rest_hosts()) - for (prep_request_tuple, _) in send_mock.call_args_list: - assert prep_request_tuple[0].headers.get('host') in expected_hosts_set - expected_hosts_set.remove(prep_request_tuple[0].headers.get('host')) - ably.close() - - @respx.mock - def test_no_host_fallback_nor_retries_if_custom_host(self): - custom_host = 'example.org' - ably = AblyRestSync(token="foo", rest_host=custom_host) - - mock_route = respx.get("https://example.org").mock(side_effect=httpx.RequestError('')) - - with pytest.raises(httpx.RequestError): - ably.http.make_request('GET', '/', skip_auth=True) - - assert mock_route.call_count == 1 - assert respx.calls.call_count == 1 - - ably.close() - - # RSC15f - def test_cached_fallback(self): - timeout = 2000 - ably = TestApp.get_ably_rest(fallback_retry_timeout=timeout) - host = ably.options.get_rest_host() - - state = {'errors': 0} - client = httpx.Client(http2=True) - send = client.send - - def side_effect(*args, **kwargs): - if args[1].url.host == host: - state['errors'] += 1 - raise RuntimeError - return send(args[1]) - - with mock.patch('httpx.Client.send', side_effect=side_effect, autospec=True): - # The main host is called and there's an error - ably.time() - assert state['errors'] == 1 - - # The cached host is used: no error - ably.time() - ably.time() - ably.time() - assert state['errors'] == 1 - - # The cached host has expired, we've an error again - time.sleep(timeout / 1000.0) - ably.time() - assert state['errors'] == 2 - - client.close() - ably.close() - - @respx.mock - def test_no_retry_if_not_500_to_599_http_code(self): - default_host = Options().get_rest_host() - ably = AblyRestSync(token="foo") - - default_url = "%s://%s:%d/" % ( - ably.http.preferred_scheme, - default_host, - ably.http.preferred_port) - - mock_response = httpx.Response(600, json={'message': "", 'status_code': 600, 'code': 50500}) - - mock_route = respx.get(default_url).mock(return_value=mock_response) - - with pytest.raises(AblyException): - ably.http.make_request('GET', '/', skip_auth=True) - - assert mock_route.call_count == 1 - assert respx.calls.call_count == 1 - - ably.close() - - def test_500_errors(self): - """ - Raise error if all the servers reply with a 5xx error. - https://github.com/ably/ably-python/issues/160 - """ - - ably = AblyRestSync(token="foo") - - def raise_ably_exception(*args, **kwargs): - raise AblyException(message="", status_code=500, code=50000) - - with mock.patch('httpx.Request', wraps=httpx.Request): - with mock.patch('ably.sync.util.exceptions.AblyException.raise_for_response', - side_effect=raise_ably_exception) as send_mock: - with pytest.raises(AblyException): - ably.http.make_request('GET', '/', skip_auth=True) - - assert send_mock.call_count == 3 - ably.close() - - def test_custom_http_timeouts(self): - ably = AblyRestSync( - token="foo", http_request_timeout=30, http_open_timeout=8, - http_max_retry_count=6, http_max_retry_duration=20) - - assert ably.http.http_request_timeout == 30 - assert ably.http.http_open_timeout == 8 - assert ably.http.http_max_retry_count == 6 - assert ably.http.http_max_retry_duration == 20 - - # RSC7a, RSC7b - def test_request_headers(self): - ably = TestApp.get_ably_rest() - r = ably.http.make_request('HEAD', '/time', skip_auth=True) - - # API - assert 'X-Ably-Version' in r.request.headers - assert r.request.headers['X-Ably-Version'] == '3' - - # Agent - assert 'Ably-Agent' in r.request.headers - expr = r"^ably-python\/\d.\d.\d(-beta\.\d)? python\/\d.\d+.\d+$" - assert re.search(expr, r.request.headers['Ably-Agent']) - ably.close() - - # RSC7c - def test_add_request_ids(self): - # With request id - ably = TestApp.get_ably_rest(add_request_ids=True) - r = ably.http.make_request('HEAD', '/time', skip_auth=True) - assert 'request_id' in r.request.url.params - request_id1 = r.request.url.params['request_id'] - assert len(base64.urlsafe_b64decode(request_id1)) == 12 - - # With request id and new request - r = ably.http.make_request('HEAD', '/time', skip_auth=True) - assert 'request_id' in r.request.url.params - request_id2 = r.request.url.params['request_id'] - assert len(base64.urlsafe_b64decode(request_id2)) == 12 - assert request_id1 != request_id2 - ably.close() - - # With request id and new request - ably = TestApp.get_ably_rest() - r = ably.http.make_request('HEAD', '/time', skip_auth=True) - assert 'request_id' not in r.request.url.params - ably.close() - - def test_request_over_http2(self): - url = 'https://www.example.com' - respx.get(url).mock(return_value=Response(status_code=200)) - - ably = TestApp.get_ably_rest(rest_host=url) - r = ably.http.make_request('GET', url, skip_auth=True) - assert r.http_version == 'HTTP/2' - ably.close() diff --git a/test/ably/sync/rest/sync_restinit_test.py b/test/ably/sync/rest/sync_restinit_test.py deleted file mode 100644 index 99837890..00000000 --- a/test/ably/sync/rest/sync_restinit_test.py +++ /dev/null @@ -1,227 +0,0 @@ -from mock import patch -import pytest -from httpx import Client - -from ably.sync import AblyRestSync -from ably.sync import AblyException -from ably.sync.transport.defaults import Defaults -from ably.sync.types.tokendetails import TokenDetails - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - - -class TestRestInit(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - - @dont_vary_protocol - def test_key_only(self): - ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"]) - assert ably.options.key_name == self.test_vars["keys"][0]["key_name"], "Key name does not match" - assert ably.options.key_secret == self.test_vars["keys"][0]["key_secret"], "Key secret does not match" - - def per_protocol_setup(self, use_binary_protocol): - self.use_binary_protocol = use_binary_protocol - - @dont_vary_protocol - def test_with_token(self): - ably = AblyRestSync(token="foo") - assert ably.options.auth_token == "foo", "Token not set at options" - - @dont_vary_protocol - def test_with_token_details(self): - td = TokenDetails() - ably = AblyRestSync(token_details=td) - assert ably.options.token_details is td - - @dont_vary_protocol - def test_with_options_token_callback(self): - def token_callback(**params): - return "this_is_not_really_a_token_request" - AblyRestSync(auth_callback=token_callback) - - @dont_vary_protocol - def test_ambiguous_key_raises_value_error(self): - with pytest.raises(ValueError, match="mutually exclusive"): - AblyRestSync(key=self.test_vars["keys"][0]["key_str"], key_name='x') - with pytest.raises(ValueError, match="mutually exclusive"): - AblyRestSync(key=self.test_vars["keys"][0]["key_str"], key_secret='x') - - @dont_vary_protocol - def test_with_key_name_or_secret_only(self): - with pytest.raises(ValueError, match="key is missing"): - AblyRestSync(key_name='x') - with pytest.raises(ValueError, match="key is missing"): - AblyRestSync(key_secret='x') - - @dont_vary_protocol - def test_with_key_name_and_secret(self): - ably = AblyRestSync(key_name="foo", key_secret="bar") - assert ably.options.key_name == "foo", "Key name does not match" - assert ably.options.key_secret == "bar", "Key secret does not match" - - @dont_vary_protocol - def test_with_options_auth_url(self): - AblyRestSync(auth_url='not_really_an_url') - - # RSC11 - @dont_vary_protocol - def test_rest_host_and_environment(self): - # rest host - ably = AblyRestSync(token='foo', rest_host="some.other.host") - assert "some.other.host" == ably.options.rest_host, "Unexpected host mismatch" - - # environment: production - ably = AblyRestSync(token='foo', environment="production") - host = ably.options.get_rest_host() - assert "rest.ably.io" == host, "Unexpected host mismatch %s" % host - - # environment: other - ably = AblyRestSync(token='foo', environment="sandbox") - host = ably.options.get_rest_host() - assert "sandbox-rest.ably.io" == host, "Unexpected host mismatch %s" % host - - # both, as per #TO3k2 - with pytest.raises(ValueError): - ably = AblyRestSync(token='foo', rest_host="some.other.host", - environment="some.other.environment") - - # RSC15 - @dont_vary_protocol - def test_fallback_hosts(self): - # Specify the fallback_hosts (RSC15a) - fallback_hosts = [ - ['fallback1.com', 'fallback2.com'], - [], - ] - - # Fallback hosts specified (RSC15g1) - for aux in fallback_hosts: - ably = AblyRestSync(token='foo', fallback_hosts=aux) - assert sorted(aux) == sorted(ably.options.get_fallback_rest_hosts()) - - # Specify environment (RSC15g2) - ably = AblyRestSync(token='foo', environment='sandbox', http_max_retry_count=10) - assert sorted(Defaults.get_environment_fallback_hosts('sandbox')) == sorted( - ably.options.get_fallback_rest_hosts()) - - # Fallback hosts and environment not specified (RSC15g3) - ably = AblyRestSync(token='foo', http_max_retry_count=10) - assert sorted(Defaults.fallback_hosts) == sorted(ably.options.get_fallback_rest_hosts()) - - # RSC15f - ably = AblyRestSync(token='foo') - assert 600000 == ably.options.fallback_retry_timeout - ably = AblyRestSync(token='foo', fallback_retry_timeout=1000) - assert 1000 == ably.options.fallback_retry_timeout - - @dont_vary_protocol - def test_specified_realtime_host(self): - ably = AblyRestSync(token='foo', realtime_host="some.other.host") - assert "some.other.host" == ably.options.realtime_host, "Unexpected host mismatch" - - @dont_vary_protocol - def test_specified_port(self): - ably = AblyRestSync(token='foo', port=9998, tls_port=9999) - assert 9999 == Defaults.get_port(ably.options),\ - "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port - - @dont_vary_protocol - def test_specified_non_tls_port(self): - ably = AblyRestSync(token='foo', port=9998, tls=False) - assert 9998 == Defaults.get_port(ably.options),\ - "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port - - @dont_vary_protocol - def test_specified_tls_port(self): - ably = AblyRestSync(token='foo', tls_port=9999, tls=True) - assert 9999 == Defaults.get_port(ably.options),\ - "Unexpected port mismatch. Expected: 9999. Actual: %d" % ably.options.tls_port - - @dont_vary_protocol - def test_tls_defaults_to_true(self): - ably = AblyRestSync(token='foo') - assert ably.options.tls, "Expected encryption to default to true" - assert Defaults.tls_port == Defaults.get_port(ably.options), "Unexpected port mismatch" - - @dont_vary_protocol - def test_tls_can_be_disabled(self): - ably = AblyRestSync(token='foo', tls=False) - assert not ably.options.tls, "Expected encryption to be False" - assert Defaults.port == Defaults.get_port(ably.options), "Unexpected port mismatch" - - @dont_vary_protocol - def test_with_no_params(self): - with pytest.raises(ValueError): - AblyRestSync() - - @dont_vary_protocol - def test_with_no_auth_params(self): - with pytest.raises(ValueError): - AblyRestSync(port=111) - - # RSA10k - def test_query_time_param(self): - ably = TestApp.get_ably_rest(query_time=True, - use_binary_protocol=self.use_binary_protocol) - - timestamp = ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=ably.time) as server_time,\ - patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: - ably.auth.request_token() - assert local_time.call_count == 1 - assert server_time.call_count == 1 - ably.auth.request_token() - assert local_time.call_count == 2 - assert server_time.call_count == 1 - - ably.close() - - @dont_vary_protocol - def test_requests_over_https_production(self): - ably = AblyRestSync(token='token') - assert 'https://rest.ably.io' == '{0}://{1}'.format(ably.http.preferred_scheme, ably.http.preferred_host) - assert ably.http.preferred_port == 443 - - @dont_vary_protocol - def test_requests_over_http_production(self): - ably = AblyRestSync(token='token', tls=False) - assert 'http://rest.ably.io' == '{0}://{1}'.format(ably.http.preferred_scheme, ably.http.preferred_host) - assert ably.http.preferred_port == 80 - - @dont_vary_protocol - def test_request_basic_auth_over_http_fails(self): - ably = AblyRestSync(key_secret='foo', key_name='bar', tls=False) - - with pytest.raises(AblyException) as excinfo: - ably.http.get('/time', skip_auth=False) - - assert 401 == excinfo.value.status_code - assert 40103 == excinfo.value.code - assert 'Cannot use Basic Auth over non-TLS connections' == excinfo.value.message - - @dont_vary_protocol - def test_environment(self): - ably = AblyRestSync(token='token', environment='custom') - with patch.object(Client, 'send', wraps=ably.http._HttpSync__client.send) as get_mock: - try: - ably.time() - except AblyException: - pass - request = get_mock.call_args_list[0][0][0] - assert request.url == 'https://custom-rest.ably.io:443/time' - - ably.close() - - @dont_vary_protocol - def test_accepts_custom_http_timeouts(self): - ably = AblyRestSync( - token="foo", http_request_timeout=30, http_open_timeout=8, - http_max_retry_count=6, http_max_retry_duration=20) - - assert ably.options.http_request_timeout == 30 - assert ably.options.http_open_timeout == 8 - assert ably.options.http_max_retry_count == 6 - assert ably.options.http_max_retry_duration == 20 diff --git a/test/ably/sync/rest/sync_restpaginatedresult_test.py b/test/ably/sync/rest/sync_restpaginatedresult_test.py deleted file mode 100644 index 312ce100..00000000 --- a/test/ably/sync/rest/sync_restpaginatedresult_test.py +++ /dev/null @@ -1,91 +0,0 @@ -import respx -from httpx import Response - -from ably.sync.http.paginatedresult import PaginatedResultSync - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import BaseAsyncTestCase - - -class TestPaginatedResult(BaseAsyncTestCase): - - def get_response_callback(self, headers, body, status): - def callback(request): - res = request.url.params.get('page') - if res: - return Response( - status_code=status, - headers=headers, - content='[{"page": %i}]' % int(res) - ) - - return Response( - status_code=status, - headers=headers, - content=body - ) - - return callback - - def setUp(self): - self.ably = TestApp.get_ably_rest(use_binary_protocol=False) - # Mocked responses - # without specific headers - self.mocked_api = respx.mock(base_url='http://rest.ably.io') - self.ch1_route = self.mocked_api.get('/channels/channel_name/ch1') - self.ch1_route.return_value = Response( - headers={'content-type': 'application/json'}, - status_code=200, - content='[{"id": 0}, {"id": 1}]', - ) - # with headers - self.ch2_route = self.mocked_api.get('/channels/channel_name/ch2') - self.ch2_route.side_effect = self.get_response_callback( - headers={ - 'content-type': 'application/json', - 'link': - '; rel="first",' - ' ; rel="next"' - }, - body='[{"id": 0}, {"id": 1}]', - status=200 - ) - # start intercepting requests - self.mocked_api.start() - - self.paginated_result = PaginatedResultSync.paginated_query( - self.ably.http, - url='http://rest.ably.io/channels/channel_name/ch1', - response_processor=lambda response: response.to_native()) - self.paginated_result_with_headers = PaginatedResultSync.paginated_query( - self.ably.http, - url='http://rest.ably.io/channels/channel_name/ch2', - response_processor=lambda response: response.to_native()) - - def tearDown(self): - self.mocked_api.stop() - self.mocked_api.reset() - self.ably.close() - - def test_items(self): - assert len(self.paginated_result.items) == 2 - - def test_with_no_headers(self): - assert self.paginated_result.first() is None - assert self.paginated_result.next() is None - assert self.paginated_result.is_last() - - def test_with_next(self): - pag = self.paginated_result_with_headers - assert pag.has_next() - assert not pag.is_last() - - def test_first(self): - pag = self.paginated_result_with_headers - pag = pag.first() - assert pag.items[0]['page'] == 1 - - def test_next(self): - pag = self.paginated_result_with_headers - pag = pag.next() - assert pag.items[0]['page'] == 2 diff --git a/test/ably/sync/rest/sync_restpresence_test.py b/test/ably/sync/rest/sync_restpresence_test.py deleted file mode 100644 index 2789ccb0..00000000 --- a/test/ably/sync/rest/sync_restpresence_test.py +++ /dev/null @@ -1,213 +0,0 @@ -from datetime import datetime, timedelta - -import pytest -import respx - -from ably.sync.http.paginatedresult import PaginatedResultSync -from ably.sync.types.presence import PresenceMessage - -from test.ably.sync.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseAsyncTestCase -from test.ably.sync.testapp import TestApp - - -class TestPresence(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.test_vars = TestApp.get_test_vars() - self.ably = TestApp.get_ably_rest() - self.channel = self.ably.channels.get('persisted:presence_fixtures') - self.ably.options.use_binary_protocol = True - - def tearDown(self): - self.ably.channels.release('persisted:presence_fixtures') - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - - def test_channel_presence_get(self): - presence_page = self.channel.presence.get() - assert isinstance(presence_page, PaginatedResultSync) - assert len(presence_page.items) == 6 - member = presence_page.items[0] - assert isinstance(member, PresenceMessage) - assert member.action - assert member.id - assert member.client_id - assert member.data - assert member.connection_id - assert member.timestamp - - def test_channel_presence_history(self): - presence_history = self.channel.presence.history() - assert isinstance(presence_history, PaginatedResultSync) - assert len(presence_history.items) == 6 - member = presence_history.items[0] - assert isinstance(member, PresenceMessage) - assert member.action - assert member.id - assert member.client_id - assert member.data - assert member.connection_id - assert member.timestamp - assert member.encoding - - def test_presence_get_encoded(self): - presence_history = self.channel.presence.history() - assert presence_history.items[-1].data == "true" - assert presence_history.items[-2].data == "24" - assert presence_history.items[-3].data == "This is a string clientData payload" - # this one doesn't have encoding field - assert presence_history.items[-4].data == '{ "test": "This is a JSONObject clientData payload"}' - assert presence_history.items[-5].data == {"example": {"json": "Object"}} - - def test_timestamp_is_datetime(self): - presence_page = self.channel.presence.get() - member = presence_page.items[0] - assert isinstance(member.timestamp, datetime) - - def test_presence_message_has_correct_member_key(self): - presence_page = self.channel.presence.get() - member = presence_page.items[0] - - assert member.member_key == "%s:%s" % (member.connection_id, member.client_id) - - def presence_mock_url(self): - kwargs = { - 'scheme': 'https' if self.test_vars['tls'] else 'http', - 'host': self.test_vars['host'] - } - port = self.test_vars['tls_port'] if self.test_vars.get('tls') else kwargs['port'] - if port == 80: - kwargs['port_sufix'] = '' - else: - kwargs['port_sufix'] = ':' + str(port) - url = '{scheme}://{host}{port_sufix}/channels/persisted%3Apresence_fixtures/presence' - return url.format(**kwargs) - - def history_mock_url(self): - kwargs = { - 'scheme': 'https' if self.test_vars['tls'] else 'http', - 'host': self.test_vars['host'] - } - port = self.test_vars['tls_port'] if self.test_vars.get('tls') else kwargs['port'] - if port == 80: - kwargs['port_sufix'] = '' - else: - kwargs['port_sufix'] = ':' + str(port) - url = '{scheme}://{host}{port_sufix}/channels/persisted%3Apresence_fixtures/presence/history' - return url.format(**kwargs) - - @dont_vary_protocol - @respx.mock - def test_get_presence_default_limit(self): - url = self.presence_mock_url() - self.respx_add_empty_msg_pack(url) - self.channel.presence.get() - assert 'limit' not in respx.calls[0].request.url.params.keys() - - @dont_vary_protocol - @respx.mock - def test_get_presence_with_limit(self): - url = self.presence_mock_url() - self.respx_add_empty_msg_pack(url) - self.channel.presence.get(300) - assert '300' == respx.calls[0].request.url.params.get('limit') - - @dont_vary_protocol - @respx.mock - def test_get_presence_max_limit_is_1000(self): - url = self.presence_mock_url() - self.respx_add_empty_msg_pack(url) - with pytest.raises(ValueError): - self.channel.presence.get(5000) - - @dont_vary_protocol - @respx.mock - def test_history_default_limit(self): - url = self.history_mock_url() - self.respx_add_empty_msg_pack(url) - self.channel.presence.history() - assert 'limit' not in respx.calls[0].request.url.params.keys() - - @dont_vary_protocol - @respx.mock - def test_history_with_limit(self): - url = self.history_mock_url() - self.respx_add_empty_msg_pack(url) - self.channel.presence.history(300) - assert '300' == respx.calls[0].request.url.params.get('limit') - - @dont_vary_protocol - @respx.mock - def test_history_with_direction(self): - url = self.history_mock_url() - self.respx_add_empty_msg_pack(url) - self.channel.presence.history(direction='backwards') - assert 'backwards' == respx.calls[0].request.url.params.get('direction') - - @dont_vary_protocol - @respx.mock - def test_history_max_limit_is_1000(self): - url = self.history_mock_url() - self.respx_add_empty_msg_pack(url) - with pytest.raises(ValueError): - self.channel.presence.history(5000) - - @dont_vary_protocol - @respx.mock - def test_with_milisecond_start_end(self): - url = self.history_mock_url() - self.respx_add_empty_msg_pack(url) - self.channel.presence.history(start=100000, end=100001) - assert '100000' == respx.calls[0].request.url.params.get('start') - assert '100001' == respx.calls[0].request.url.params.get('end') - - @dont_vary_protocol - @respx.mock - def test_with_timedate_startend(self): - url = self.history_mock_url() - start = datetime(2015, 8, 15, 17, 11, 44, 706539) - start_ms = 1439658704706 - end = start + timedelta(hours=1) - end_ms = start_ms + (1000 * 60 * 60) - self.respx_add_empty_msg_pack(url) - self.channel.presence.history(start=start, end=end) - assert str(start_ms) in respx.calls[0].request.url.params.get('start') - assert str(end_ms) in respx.calls[0].request.url.params.get('end') - - @dont_vary_protocol - @respx.mock - def test_with_start_gt_end(self): - url = self.history_mock_url() - end = datetime(2015, 8, 15, 17, 11, 44, 706539) - start = end + timedelta(hours=1) - self.respx_add_empty_msg_pack(url) - with pytest.raises(ValueError, match="'end' parameter has to be greater than or equal to 'start'"): - self.channel.presence.history(start=start, end=end) - - -class TestPresenceCrypt(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - key = b'0123456789abcdef' - self.channel = self.ably.channels.get('persisted:presence_fixtures', cipher={'key': key}) - - def tearDown(self): - self.ably.channels.release('persisted:presence_fixtures') - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - - def test_presence_history_encrypted(self): - presence_history = self.channel.presence.history() - assert presence_history.items[0].data == {'foo': 'bar'} - - def test_presence_get_encrypted(self): - messages = self.channel.presence.get() - messages = (msg for msg in messages.items if msg.client_id == 'client_encoded') - message = next(messages) - - assert message.data == {'foo': 'bar'} diff --git a/test/ably/sync/rest/sync_restpush_test.py b/test/ably/sync/rest/sync_restpush_test.py deleted file mode 100644 index d8114c32..00000000 --- a/test/ably/sync/rest/sync_restpush_test.py +++ /dev/null @@ -1,398 +0,0 @@ -import itertools -import random -import string -import time - -import pytest - -from ably.sync import AblyException, AblyAuthException -from ably.sync import DeviceDetails, PushChannelSubscription -from ably.sync.http.paginatedresult import PaginatedResultSync - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, BaseAsyncTestCase -from test.ably.sync.utils import new_dict, random_string, get_random_key - - -DEVICE_TOKEN = '740f4707bebcf74f9b7c25d48e3358945f6aa01da5ddb387462c7eaf61bb78ad' - - -class TestPush(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - - # Register several devices for later use - self.devices = {} - for i in range(10): - self.save_device() - - # Register several subscriptions for later use - self.channels = {'canpublish:test1': [], 'canpublish:test2': [], 'canpublish:test3': []} - for key, channel in zip(self.devices, itertools.cycle(self.channels)): - device = self.devices[key] - self.save_subscription(channel, device_id=device.id) - assert len(list(itertools.chain(*self.channels.values()))) == len(self.devices) - - def tearDown(self): - for key, channel in zip(self.devices, itertools.cycle(self.channels)): - device = self.devices[key] - self.remove_subscription(channel, device_id=device.id) - self.ably.push.admin.device_registrations.remove(device_id=device.id) - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - - def get_client_id(self): - return random_string(12) - - def get_device_id(self): - return random_string(26, string.ascii_uppercase + string.digits) - - def gen_device_data(self, data=None, **kw): - if data is None: - data = { - 'id': self.get_device_id(), - 'clientId': self.get_client_id(), - 'platform': random.choice(['android', 'ios']), - 'formFactor': 'phone', - 'deviceSecret': 'test-secret', - 'push': { - 'recipient': { - 'transportType': 'apns', - 'deviceToken': DEVICE_TOKEN, - } - }, - } - else: - data = data.copy() - - data.update(kw) - return data - - def save_device(self, data=None, **kw): - """ - Helper method to register a device, to not have this code repeated - everywhere. Returns the input dict that was sent to Ably, and the - device details returned by Ably. - """ - data = self.gen_device_data(data, **kw) - device = self.ably.push.admin.device_registrations.save(data) - self.devices[device.id] = device - return device - - def remove_device(self, device_id): - result = self.ably.push.admin.device_registrations.remove(device_id) - self.devices.pop(device_id, None) - return result - - def remove_device_where(self, **kw): - remove_where = self.ably.push.admin.device_registrations.remove_where - result = remove_where(**kw) - - aux = {'deviceId': 'id', 'clientId': 'client_id'} - for device in list(self.devices.values()): - for key, value in kw.items(): - key = aux[key] - if getattr(device, key) == value: - del self.devices[device.id] - - return result - - def get_device(self): - key = get_random_key(self.devices) - return self.devices[key] - - def get_channel(self): - key = get_random_key(self.channels) - return key, self.channels[key] - - def save_subscription(self, channel, **kw): - """ - Helper method to register a device, to not have this code repeated - everywhere. Returns the input dict that was sent to Ably, and the - device details returned by Ably. - """ - subscription = PushChannelSubscription(channel, **kw) - subscription = self.ably.push.admin.channel_subscriptions.save(subscription) - self.channels.setdefault(channel, []).append(subscription) - return subscription - - def remove_subscription(self, channel, **kw): - subscription = PushChannelSubscription(channel, **kw) - subscription = self.ably.push.admin.channel_subscriptions.remove(subscription) - return subscription - - # RSH1a - def test_admin_publish(self): - recipient = {'clientId': 'ablyChannel'} - data = { - 'data': {'foo': 'bar'}, - } - - publish = self.ably.push.admin.publish - with pytest.raises(TypeError): - publish('ablyChannel', data) - with pytest.raises(TypeError): - publish(recipient, 25) - with pytest.raises(ValueError): - publish({}, data) - with pytest.raises(ValueError): - publish(recipient, {}) - - with pytest.raises(AblyException): - publish(recipient, {'xxx': 5}) - - assert publish(recipient, data) is None - - # RSH1b1 - def test_admin_device_registrations_get(self): - get = self.ably.push.admin.device_registrations.get - - # Not found - with pytest.raises(AblyException): - get('not-found') - - # Found - device = self.get_device() - device_details = get(device.id) - assert device_details.id == device.id - assert device_details.platform == device.platform - assert device_details.form_factor == device.form_factor - - # RSH1b2 - def test_admin_device_registrations_list(self): - list_devices = self.ably.push.admin.device_registrations.list - - list_response = list_devices() - assert type(list_response) is PaginatedResultSync - assert type(list_response.items) is list - assert type(list_response.items[0]) is DeviceDetails - - # limit - list_response = list_devices(limit=5000) - assert len(list_response.items) == len(self.devices) - list_response = list_devices(limit=2) - assert len(list_response.items) == 2 - - # Filter by device id - device = self.get_device() - list_response = list_devices(deviceId=device.id) - assert len(list_response.items) == 1 - list_response = list_devices(deviceId=self.get_device_id()) - assert len(list_response.items) == 0 - - # Filter by client id - list_response = list_devices(clientId=device.client_id) - assert len(list_response.items) == 1 - list_response = list_devices(clientId=self.get_client_id()) - assert len(list_response.items) == 0 - - # RSH1b3 - def test_admin_device_registrations_save(self): - # Create - data = self.gen_device_data() - device = self.save_device(data) - assert type(device) is DeviceDetails - - # Update - self.save_device(data, formFactor='tablet') - - # Invalid values - with pytest.raises(ValueError): - push = {'recipient': new_dict(data['push']['recipient'], transportType='xyz')} - self.save_device(data, push=push) - with pytest.raises(ValueError): - self.save_device(data, platform='native') - with pytest.raises(ValueError): - self.save_device(data, formFactor='fridge') - - # Fail - with pytest.raises(AblyException): - self.save_device(data, push={'color': 'red'}) - - # RSH1b4 - def test_admin_device_registrations_remove(self): - get = self.ably.push.admin.device_registrations.get - - device = self.get_device() - - # Remove - get_response = get(device.id) - assert get_response.id == device.id # Exists - remove_device_response = self.remove_device(device.id) - assert remove_device_response.status_code == 204 - with pytest.raises(AblyException): # Doesn't exist - get(device.id) - - # Remove again, it doesn't fail - remove_device_response = self.remove_device(device.id) - assert remove_device_response.status_code == 204 - - # RSH1b5 - def test_admin_device_registrations_remove_where(self): - get = self.ably.push.admin.device_registrations.get - - # Remove by device id - device = self.get_device() - foo_device = get(device.id) - assert foo_device.id == device.id # Exists - remove_foo_device_response = self.remove_device_where(deviceId=device.id) - assert remove_foo_device_response.status_code == 204 - with pytest.raises(AblyException): # Doesn't exist - get(device.id) - - # Remove by client id - device = self.get_device() - boo_device = get(device.id) - assert boo_device.id == device.id # Exists - remove_boo_device_response = self.remove_device_where(clientId=device.client_id) - assert remove_boo_device_response.status_code == 204 - # Doesn't exist (Deletion is async: wait up to a few seconds before giving up) - with pytest.raises(AblyException): - for i in range(5): - time.sleep(1) - get(device.id) - - # Remove with no matching params - remove_boo_device_response = self.remove_device_where(clientId=device.client_id) - assert remove_boo_device_response.status_code == 204 - - # # RSH1c1 - def test_admin_channel_subscriptions_list(self): - list_ = self.ably.push.admin.channel_subscriptions.list - - channel, subscriptions = self.get_channel() - - list_response = list_(channel=channel) - - assert type(list_response) is PaginatedResultSync - assert type(list_response.items) is list - assert type(list_response.items[0]) is PushChannelSubscription - - # limit - list_response = list_(channel=channel, limit=2) - assert len(list_response.items) == 2 - - list_response = list_(channel=channel, limit=5000) - assert len(list_response.items) == len(subscriptions) - - # Filter by device id - device_id = subscriptions[0].device_id - list_response = list_(channel=channel, deviceId=device_id) - assert len(list_response.items) == 1 - assert list_response.items[0].device_id == device_id - assert list_response.items[0].channel == channel - list_response = list_(channel=channel, deviceId=self.get_device_id()) - assert len(list_response.items) == 0 - - # Filter by client id - device = self.get_device() - list_response = list_(channel=channel, clientId=device.client_id) - assert len(list_response.items) == 0 - - # RSH1c2 - def test_admin_channels_list(self): - list_ = self.ably.push.admin.channel_subscriptions.list_channels - - list_response = list_() - assert type(list_response) is PaginatedResultSync - assert type(list_response.items) is list - assert type(list_response.items[0]) is str - - # limit - list_response = list_(limit=5000) - assert len(list_response.items) == len(self.channels) - list_response = list_(limit=1) - assert len(list_response.items) == 1 - - # RSH1c3 - def test_admin_channel_subscriptions_save(self): - save = self.ably.push.admin.channel_subscriptions.save - - # Subscribe - device = self.get_device() - channel = 'canpublish:testsave' - subscription = self.save_subscription(channel, device_id=device.id) - assert type(subscription) is PushChannelSubscription - assert subscription.channel == channel - assert subscription.device_id == device.id - assert subscription.client_id is None - - # Failures - client_id = self.get_client_id() - with pytest.raises(ValueError): - PushChannelSubscription(channel, device_id=device.id, client_id=client_id) - - subscription = PushChannelSubscription('notallowed', device_id=device.id) - with pytest.raises(AblyAuthException): - save(subscription) - - subscription = PushChannelSubscription(channel, device_id='notregistered') - with pytest.raises(AblyException): - save(subscription) - - # RSH1c4 - def test_admin_channel_subscriptions_remove(self): - save = self.ably.push.admin.channel_subscriptions.save - remove = self.ably.push.admin.channel_subscriptions.remove - list_ = self.ably.push.admin.channel_subscriptions.list - - channel = 'canpublish:testremove' - - # Subscribe device - device = self.get_device() - subscription = save(PushChannelSubscription(channel, device_id=device.id)) - list_response = list_(channel=channel) - assert device.id in (x.device_id for x in list_response.items) - remove_response = remove(subscription) - assert remove_response.status_code == 204 - list_response = list_(channel=channel) - assert device.id not in (x.device_id for x in list_response.items) - - # Subscribe client - client_id = self.get_client_id() - subscription = save(PushChannelSubscription(channel, client_id=client_id)) - list_response = list_(channel=channel) - assert client_id in (x.client_id for x in list_response.items) - remove_response = remove(subscription) - assert remove_response.status_code == 204 - list_response = list_(channel=channel) - assert client_id not in (x.client_id for x in list_response.items) - - # Remove again, it doesn't fail - remove_response = remove(subscription) - assert remove_response.status_code == 204 - - # RSH1c5 - def test_admin_channel_subscriptions_remove_where(self): - save = self.ably.push.admin.channel_subscriptions.save - remove = self.ably.push.admin.channel_subscriptions.remove_where - list_ = self.ably.push.admin.channel_subscriptions.list - - channel = 'canpublish:testremovewhere' - - # Subscribe device - device = self.get_device() - save(PushChannelSubscription(channel, device_id=device.id)) - list_response = list_(channel=channel) - assert device.id in (x.device_id for x in list_response.items) - remove_response = remove(channel=channel, device_id=device.id) - assert remove_response.status_code == 204 - list_response = list_(channel=channel) - assert device.id not in (x.device_id for x in list_response.items) - - # Subscribe client - client_id = self.get_client_id() - save(PushChannelSubscription(channel, client_id=client_id)) - list_response = list_(channel=channel) - assert client_id in (x.client_id for x in list_response.items) - remove_response = remove(channel=channel, client_id=client_id) - assert remove_response.status_code == 204 - list_response = list_(channel=channel) - assert client_id not in (x.client_id for x in list_response.items) - - # Remove again, it doesn't fail - remove_response = remove(channel=channel, client_id=client_id) - assert remove_response.status_code == 204 diff --git a/test/ably/sync/rest/sync_restrequest_test.py b/test/ably/sync/rest/sync_restrequest_test.py deleted file mode 100644 index 8c090ac7..00000000 --- a/test/ably/sync/rest/sync_restrequest_test.py +++ /dev/null @@ -1,132 +0,0 @@ -import httpx -import pytest -import respx - -from ably.sync import AblyRestSync -from ably.sync.http.paginatedresult import HttpPaginatedResponseSync -from ably.sync.transport.defaults import Defaults -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import BaseAsyncTestCase -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol - - -# RSC19 -class TestRestRequest(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - self.test_vars = TestApp.get_test_vars() - - # Populate the channel (using the new api) - self.channel = self.get_channel_name() - self.path = '/channels/%s/messages' % self.channel - for i in range(20): - body = {'name': 'event%s' % i, 'data': 'lorem ipsum %s' % i} - self.ably.request('POST', self.path, body=body, version=Defaults.protocol_version) - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - def test_post(self): - body = {'name': 'test-post', 'data': 'lorem ipsum'} - result = self.ably.request('POST', self.path, body=body, version=Defaults.protocol_version) - - assert isinstance(result, HttpPaginatedResponseSync) # RSC19d - # HP3 - assert type(result.items) is list - assert len(result.items) == 1 - assert result.items[0]['channel'] == self.channel - assert 'messageId' in result.items[0] - - def test_get(self): - params = {'limit': 10, 'direction': 'forwards'} - result = self.ably.request('GET', self.path, params=params, version=Defaults.protocol_version) - - assert isinstance(result, HttpPaginatedResponseSync) # RSC19d - - # HP2 - assert isinstance(result.next(), HttpPaginatedResponseSync) - assert isinstance(result.first(), HttpPaginatedResponseSync) - - # HP3 - assert isinstance(result.items, list) - item = result.items[0] - assert isinstance(item, dict) - assert 'timestamp' in item - assert 'id' in item - assert item['name'] == 'event0' - assert item['data'] == 'lorem ipsum 0' - - assert result.status_code == 200 # HP4 - assert result.success is True # HP5 - assert result.error_code is None # HP6 - assert result.error_message is None # HP7 - assert isinstance(result.headers, list) # HP7 - - @dont_vary_protocol - def test_not_found(self): - result = self.ably.request('GET', '/not-found', version=Defaults.protocol_version) - assert isinstance(result, HttpPaginatedResponseSync) # RSC19d - assert result.status_code == 404 # HP4 - assert result.success is False # HP5 - - @dont_vary_protocol - def test_error(self): - params = {'limit': 'abc'} - result = self.ably.request('GET', self.path, params=params, version=Defaults.protocol_version) - assert isinstance(result, HttpPaginatedResponseSync) # RSC19d - assert result.status_code == 400 # HP4 - assert not result.success - assert result.error_code - assert result.error_message - - def test_headers(self): - key = 'X-Test' - value = 'lorem ipsum' - result = self.ably.request('GET', '/time', headers={key: value}, version=Defaults.protocol_version) - assert result.response.request.headers[key] == value - - # RSC19e - @dont_vary_protocol - def test_timeout(self): - # Timeout - timeout = 0.000001 - ably = AblyRestSync(token="foo", http_request_timeout=timeout) - assert ably.http.http_request_timeout == timeout - with pytest.raises(httpx.ReadTimeout): - ably.request('GET', '/time', version=Defaults.protocol_version) - ably.close() - - default_endpoint = 'https://sandbox-rest.ably.io/time' - fallback_host = 'sandbox-a-fallback.ably-realtime.com' - fallback_endpoint = f'https://{fallback_host}/time' - ably = TestApp.get_ably_rest(fallback_hosts=[fallback_host]) - with respx.mock: - default_route = respx.get(default_endpoint) - fallback_route = respx.get(fallback_endpoint) - headers = { - "Content-Type": "application/json" - } - default_route.side_effect = httpx.ConnectError('') - fallback_route.return_value = httpx.Response(200, headers=headers, text='[123]') - ably.request('GET', '/time', version=Defaults.protocol_version) - ably.close() - - # Bad host, no Fallback - ably = AblyRestSync(key=self.test_vars["keys"][0]["key_str"], - rest_host='some.other.host', - port=self.test_vars["port"], - tls_port=self.test_vars["tls_port"], - tls=self.test_vars["tls"]) - with pytest.raises(httpx.ConnectError): - ably.request('GET', '/time', version=Defaults.protocol_version) - ably.close() - - def test_version(self): - version = "150" # chosen arbitrarily - result = self.ably.request('GET', '/time', "150") - assert result.response.request.headers["X-Ably-Version"] == version diff --git a/test/ably/sync/rest/sync_reststats_test.py b/test/ably/sync/rest/sync_reststats_test.py deleted file mode 100644 index dd2c91bc..00000000 --- a/test/ably/sync/rest/sync_reststats_test.py +++ /dev/null @@ -1,310 +0,0 @@ -from datetime import datetime -from datetime import timedelta -import logging - -import pytest - -from ably.sync.types.stats import Stats -from ably.sync.util.exceptions import AblyException -from ably.sync.http.paginatedresult import PaginatedResultSync - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - -log = logging.getLogger(__name__) - - -class TestRestAppStatsSetup: - __stats_added = False - - def get_params(self): - return { - 'start': self.last_interval, - 'end': self.last_interval, - 'unit': 'minute', - 'limit': 1 - } - - def setUp(self): - self.ably = TestApp.get_ably_rest() - self.ably_text = TestApp.get_ably_rest(use_binary_protocol=False) - - self.last_year = datetime.now().year - 1 - self.previous_year = datetime.now().year - 2 - self.last_interval = datetime(self.last_year, 2, 3, 15, 5) - self.previous_interval = datetime(self.previous_year, 2, 3, 15, 5) - previous_year_stats = 120 - stats = [ - { - 'intervalId': Stats.to_interval_id(self.last_interval - timedelta(minutes=2), - 'minute'), - 'inbound': {'realtime': {'messages': {'count': 50, 'data': 5000}}}, - 'outbound': {'realtime': {'messages': {'count': 20, 'data': 2000}}} - }, - { - 'intervalId': Stats.to_interval_id(self.last_interval - timedelta(minutes=1), - 'minute'), - 'inbound': {'realtime': {'messages': {'count': 60, 'data': 6000}}}, - 'outbound': {'realtime': {'messages': {'count': 10, 'data': 1000}}} - }, - { - 'intervalId': Stats.to_interval_id(self.last_interval, 'minute'), - 'inbound': {'realtime': {'messages': {'count': 70, 'data': 7000}}}, - 'outbound': {'realtime': {'messages': {'count': 40, 'data': 4000}}}, - 'persisted': {'presence': {'count': 20, 'data': 2000}}, - 'connections': {'tls': {'peak': 20, 'opened': 10}}, - 'channels': {'peak': 50, 'opened': 30}, - 'apiRequests': {'succeeded': 50, 'failed': 10}, - 'tokenRequests': {'succeeded': 60, 'failed': 20}, - } - ] - - previous_stats = [] - for i in range(previous_year_stats): - previous_stats.append( - { - 'intervalId': Stats.to_interval_id(self.previous_interval - timedelta(minutes=i), - 'minute'), - 'inbound': {'realtime': {'messages': {'count': i}}} - } - ) - # asynctest does not support setUpClass method - if TestRestAppStatsSetup.__stats_added: - return - self.ably.http.post('/stats', body=stats + previous_stats) - TestRestAppStatsSetup.__stats_added = True - - def tearDown(self): - self.ably.close() - self.ably_text.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - - -class TestDirectionForwards(TestRestAppStatsSetup, BaseAsyncTestCase, - metaclass=VaryByProtocolTestsMetaclass): - - def get_params(self): - return { - 'start': self.last_interval - timedelta(minutes=2), - 'end': self.last_interval, - 'unit': 'minute', - 'direction': 'forwards', - 'limit': 1 - } - - def test_stats_are_forward(self): - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.inbound.realtime.all.count"] == 50 - - def test_three_pages(self): - stats_pages = self.ably.stats(**self.get_params()) - assert not stats_pages.is_last() - page2 = stats_pages.next() - page3 = page2.next() - assert page3.items[0].entries["messages.inbound.realtime.all.count"] == 70 - - -class TestDirectionBackwards(TestRestAppStatsSetup, BaseAsyncTestCase, - metaclass=VaryByProtocolTestsMetaclass): - - def get_params(self): - return { - 'end': self.last_interval, - 'unit': 'minute', - 'direction': 'backwards', - 'limit': 1 - } - - def test_stats_are_forward(self): - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.inbound.realtime.all.count"] == 70 - - def test_three_pages(self): - stats_pages = self.ably.stats(**self.get_params()) - assert not stats_pages.is_last() - page2 = stats_pages.next() - page3 = page2.next() - assert not stats_pages.is_last() - assert page3.items[0].entries["messages.inbound.realtime.all.count"] == 50 - - -class TestOnlyLastYear(TestRestAppStatsSetup, BaseAsyncTestCase, - metaclass=VaryByProtocolTestsMetaclass): - - def get_params(self): - return { - 'end': self.last_interval, - 'unit': 'minute', - 'limit': 3 - } - - def test_default_is_backwards(self): - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - assert stats[0].entries["messages.inbound.realtime.messages.count"] == 70 - assert stats[-1].entries["messages.inbound.realtime.messages.count"] == 50 - - -class TestPreviousYear(TestRestAppStatsSetup, BaseAsyncTestCase, - metaclass=VaryByProtocolTestsMetaclass): - - def get_params(self): - return { - 'end': self.previous_interval, - 'unit': 'minute', - } - - def test_default_100_pagination(self): - self.stats_pages = self.ably.stats(**self.get_params()) - stats = self.stats_pages.items - assert len(stats) == 100 - next_page = self.stats_pages.next() - assert len(next_page.items) == 20 - - -class TestRestAppStats(TestRestAppStatsSetup, BaseAsyncTestCase, - metaclass=VaryByProtocolTestsMetaclass): - - @dont_vary_protocol - def test_protocols(self): - stats_pages = self.ably.stats(**self.get_params()) - stats_pages1 = self.ably_text.stats(**self.get_params()) - assert len(stats_pages.items) == len(stats_pages1.items) - - def test_paginated_response(self): - stats_pages = self.ably.stats(**self.get_params()) - assert isinstance(stats_pages, PaginatedResultSync) - assert isinstance(stats_pages.items[0], Stats) - - def test_units(self): - for unit in ['hour', 'day', 'month']: - params = { - 'start': self.last_interval, - 'end': self.last_interval, - 'unit': unit, - 'direction': 'forwards', - 'limit': 1 - } - stats_pages = self.ably.stats(**params) - stat = stats_pages.items[0] - assert len(stats_pages.items) == 1 - assert stat.entries["messages.all.messages.count"] == 50 + 20 + 60 + 10 + 70 + 40 - assert stat.entries["messages.all.messages.data"] == 5000 + 2000 + 6000 + 1000 + 7000 + 4000 - - @dont_vary_protocol - def test_when_argument_start_is_after_end(self): - params = { - 'start': self.last_interval, - 'end': self.last_interval - timedelta(minutes=2), - 'unit': 'minute', - } - with pytest.raises(AblyException, match="'end' parameter has to be greater than or equal to 'start'"): - self.ably.stats(**params) - - @dont_vary_protocol - def test_when_limit_gt_1000(self): - params = { - 'end': self.last_interval, - 'limit': 5000 - } - with pytest.raises(AblyException, match="The maximum allowed limit is 1000"): - self.ably.stats(**params) - - def test_no_arguments(self): - params = { - 'end': self.last_interval, - } - stats_pages = self.ably.stats(**params) - self.stat = stats_pages.items[0] - assert self.stat.unit == 'minute' - - def test_got_1_record(self): - stats_pages = self.ably.stats(**self.get_params()) - assert 1 == len(stats_pages.items), "Expected 1 record" - - def test_return_aggregated_message_data(self): - # returns aggregated message data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.all.messages.count"] == 70 + 40 - assert stat.entries["messages.all.messages.data"] == 7000 + 4000 - - def test_inbound_realtime_all_data(self): - # returns inbound realtime all data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.inbound.realtime.all.count"] == 70 - assert stat.entries["messages.inbound.realtime.all.data"] == 7000 - - def test_inboud_realtime_message_data(self): - # returns inbound realtime message data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.inbound.realtime.messages.count"] == 70 - assert stat.entries["messages.inbound.realtime.messages.data"] == 7000 - - def test_outbound_realtime_all_data(self): - # returns outboud realtime all data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.outbound.realtime.all.count"] == 40 - assert stat.entries["messages.outbound.realtime.all.data"] == 4000 - - def test_persisted_data(self): - # returns persisted presence all data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["messages.persisted.all.count"] == 20 - assert stat.entries["messages.persisted.all.data"] == 2000 - - def test_connections_data(self): - # returns connections all data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["connections.all.peak"] == 20 - assert stat.entries["connections.all.opened"] == 10 - - def test_channels_all_data(self): - # returns channels all data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["channels.peak"] == 50 - assert stat.entries["channels.opened"] == 30 - - def test_api_requests_data(self): - # returns api_requests data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["apiRequests.other.succeeded"] == 50 - assert stat.entries["apiRequests.other.failed"] == 10 - - def test_token_requests(self): - # returns token_requests data - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.entries["apiRequests.tokenRequests.succeeded"] == 60 - assert stat.entries["apiRequests.tokenRequests.failed"] == 20 - - def test_interval(self): - # interval - stats_pages = self.ably.stats(**self.get_params()) - stats = stats_pages.items - stat = stats[0] - assert stat.unit == 'minute' - assert stat.interval_id == self.last_interval.strftime('%Y-%m-%d:%H:%M') - assert stat.interval_time == self.last_interval diff --git a/test/ably/sync/rest/sync_resttime_test.py b/test/ably/sync/rest/sync_resttime_test.py deleted file mode 100644 index 70116864..00000000 --- a/test/ably/sync/rest/sync_resttime_test.py +++ /dev/null @@ -1,43 +0,0 @@ -import time - -import pytest - -from ably.sync import AblyException - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - - -class TestRestTime(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - def setUp(self): - self.ably = TestApp.get_ably_rest() - - def tearDown(self): - self.ably.close() - - def test_time_accuracy(self): - reported_time = self.ably.time() - actual_time = time.time() * 1000.0 - - seconds = 10 - assert abs(actual_time - reported_time) < seconds * 1000, "Time is not within %s seconds" % seconds - - def test_time_without_key_or_token(self): - reported_time = self.ably.time() - actual_time = time.time() * 1000.0 - - seconds = 10 - assert abs(actual_time - reported_time) < seconds * 1000, "Time is not within %s seconds" % seconds - - @dont_vary_protocol - def test_time_fails_without_valid_host(self): - ably = TestApp.get_ably_rest(key=None, token='foo', rest_host="this.host.does.not.exist") - with pytest.raises(AblyException): - ably.time() - - ably.close() diff --git a/test/ably/sync/rest/sync_resttoken_test.py b/test/ably/sync/rest/sync_resttoken_test.py deleted file mode 100644 index ee3a1562..00000000 --- a/test/ably/sync/rest/sync_resttoken_test.py +++ /dev/null @@ -1,342 +0,0 @@ -import datetime -import json -import logging - -from mock import patch -import pytest - -from ably.sync import AblyException -from ably.sync import AblyRestSync -from ably.sync import Capability -from ably.sync.types.tokendetails import TokenDetails -from ably.sync.types.tokenrequest import TokenRequest - -from test.ably.sync.testapp import TestApp -from test.ably.sync.utils import VaryByProtocolTestsMetaclass, dont_vary_protocol, BaseAsyncTestCase - -log = logging.getLogger(__name__) - - -class TestRestToken(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def server_time(self): - return self.ably.time() - - def setUp(self): - capability = {"*": ["*"]} - self.permit_all = str(Capability(capability)) - self.ably = TestApp.get_ably_rest() - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - def test_request_token_null_params(self): - pre_time = self.server_time() - token_details = self.ably.auth.request_token() - post_time = self.server_time() - assert token_details.token is not None, "Expected token" - assert token_details.issued + 300 >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time + 500, "Unexpected issued time" - assert self.permit_all == str(token_details.capability), "Unexpected capability" - - def test_request_token_explicit_timestamp(self): - pre_time = self.server_time() - token_details = self.ably.auth.request_token(token_params={'timestamp': pre_time}) - post_time = self.server_time() - assert token_details.token is not None, "Expected token" - assert token_details.issued + 300 >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time, "Unexpected issued time" - assert self.permit_all == str(Capability(token_details.capability)), "Unexpected Capability" - - def test_request_token_explicit_invalid_timestamp(self): - request_time = self.server_time() - explicit_timestamp = request_time - 30 * 60 * 1000 - - with pytest.raises(AblyException): - self.ably.auth.request_token(token_params={'timestamp': explicit_timestamp}) - - def test_request_token_with_system_timestamp(self): - pre_time = self.server_time() - token_details = self.ably.auth.request_token(query_time=True) - post_time = self.server_time() - assert token_details.token is not None, "Expected token" - assert token_details.issued >= pre_time, "Unexpected issued time" - assert token_details.issued <= post_time, "Unexpected issued time" - assert self.permit_all == str(Capability(token_details.capability)), "Unexpected Capability" - - def test_request_token_with_duplicate_nonce(self): - request_time = self.server_time() - token_params = { - 'timestamp': request_time, - 'nonce': '1234567890123456' - } - token_details = self.ably.auth.request_token(token_params) - assert token_details.token is not None, "Expected token" - - with pytest.raises(AblyException): - self.ably.auth.request_token(token_params) - - def test_request_token_with_capability_that_subsets_key_capability(self): - capability = Capability({ - "onlythischannel": ["subscribe"] - }) - - token_details = self.ably.auth.request_token( - token_params={'capability': capability}) - - assert token_details is not None - assert token_details.token is not None - assert capability == token_details.capability, "Unexpected capability" - - def test_request_token_with_specified_key(self): - test_vars = TestApp.get_test_vars() - key = test_vars["keys"][1] - token_details = self.ably.auth.request_token( - key_name=key["key_name"], key_secret=key["key_secret"]) - assert token_details.token is not None, "Expected token" - assert key.get("capability") == token_details.capability, "Unexpected capability" - - @dont_vary_protocol - def test_request_token_with_invalid_mac(self): - with pytest.raises(AblyException): - self.ably.auth.request_token(token_params={'mac': "thisisnotavalidmac"}) - - def test_request_token_with_specified_ttl(self): - token_details = self.ably.auth.request_token(token_params={'ttl': 100}) - assert token_details.token is not None, "Expected token" - assert token_details.issued + 100 == token_details.expires, "Unexpected expires" - - @dont_vary_protocol - def test_token_with_excessive_ttl(self): - excessive_ttl = 365 * 24 * 60 * 60 * 1000 - with pytest.raises(AblyException): - self.ably.auth.request_token(token_params={'ttl': excessive_ttl}) - - @dont_vary_protocol - def test_token_generation_with_invalid_ttl(self): - with pytest.raises(AblyException): - self.ably.auth.request_token(token_params={'ttl': -1}) - - def test_token_generation_with_local_time(self): - timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: - self.ably.auth.request_token() - assert local_time.called - assert not server_time.called - - # RSA10k - def test_token_generation_with_server_time(self): - timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: - self.ably.auth.request_token(query_time=True) - assert local_time.call_count == 1 - assert server_time.call_count == 1 - self.ably.auth.request_token(query_time=True) - assert local_time.call_count == 2 - assert server_time.call_count == 1 - - # TD7 - def test_toke_details_from_json(self): - token_details = self.ably.auth.request_token() - token_details_dict = token_details.to_dict() - token_details_str = json.dumps(token_details_dict) - - assert token_details == TokenDetails.from_json(token_details_dict) - assert token_details == TokenDetails.from_json(token_details_str) - - # Issue #71 - @dont_vary_protocol - def test_request_token_float_and_timedelta(self): - lifetime = datetime.timedelta(hours=4) - self.ably.auth.request_token({'ttl': lifetime.total_seconds() * 1000}) - self.ably.auth.request_token({'ttl': lifetime}) - - -class TestCreateTokenRequest(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): - - def setUp(self): - self.ably = TestApp.get_ably_rest() - self.key_name = self.ably.options.key_name - self.key_secret = self.ably.options.key_secret - - def tearDown(self): - self.ably.close() - - def per_protocol_setup(self, use_binary_protocol): - self.ably.options.use_binary_protocol = use_binary_protocol - self.use_binary_protocol = use_binary_protocol - - @dont_vary_protocol - def test_key_name_and_secret_are_required(self): - ably = TestApp.get_ably_rest(key=None, token='not a real token') - with pytest.raises(AblyException, match="40101 401 No key specified"): - ably.auth.create_token_request() - with pytest.raises(AblyException, match="40101 401 No key specified"): - ably.auth.create_token_request(key_name=self.key_name) - with pytest.raises(AblyException, match="40101 401 No key specified"): - ably.auth.create_token_request(key_secret=self.key_secret) - - @dont_vary_protocol - def test_with_local_time(self): - timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: - self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret, query_time=False) - assert local_time.called - assert not server_time.called - - # RSA10k - @dont_vary_protocol - def test_with_server_time(self): - timestamp = self.ably.auth._timestamp - with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time,\ - patch('ably.sync.rest.auth.AuthSync._timestamp', wraps=timestamp) as local_time: - self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret, query_time=True) - assert local_time.call_count == 1 - assert server_time.call_count == 1 - self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret, query_time=True) - assert local_time.call_count == 2 - assert server_time.call_count == 1 - - def test_token_request_can_be_used_to_get_a_token(self): - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert isinstance(token_request, TokenRequest) - - def auth_callback(token_params): - return token_request - - ably = TestApp.get_ably_rest(key=None, - auth_callback=auth_callback, - use_binary_protocol=self.use_binary_protocol) - - token = ably.auth.authorize() - assert isinstance(token, TokenDetails) - ably.close() - - def test_token_request_dict_can_be_used_to_get_a_token(self): - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert isinstance(token_request, TokenRequest) - - def auth_callback(token_params): - return token_request.to_dict() - - ably = TestApp.get_ably_rest(key=None, - auth_callback=auth_callback, - use_binary_protocol=self.use_binary_protocol) - - token = ably.auth.authorize() - assert isinstance(token, TokenDetails) - ably.close() - - # TE6 - @dont_vary_protocol - def test_token_request_from_json(self): - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert isinstance(token_request, TokenRequest) - - token_request_dict = token_request.to_dict() - assert token_request == TokenRequest.from_json(token_request_dict) - - token_request_str = json.dumps(token_request_dict) - assert token_request == TokenRequest.from_json(token_request_str) - - @dont_vary_protocol - def test_nonce_is_random_and_longer_than_15_characters(self): - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert len(token_request.nonce) > 15 - - another_token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert len(another_token_request.nonce) > 15 - - assert token_request.nonce != another_token_request.nonce - - # RSA5 - @dont_vary_protocol - def test_ttl_is_optional_and_specified_in_ms(self): - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert token_request.ttl is None - - # RSA6 - @dont_vary_protocol - def test_capability_is_optional(self): - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret) - assert token_request.capability is None - - @dont_vary_protocol - def test_accept_all_token_params(self): - token_params = { - 'ttl': 1000, - 'capability': Capability({'channel': ['publish']}), - 'client_id': 'a_id', - 'timestamp': 1000, - 'nonce': 'a_nonce', - } - token_request = self.ably.auth.create_token_request( - token_params, - key_name=self.key_name, key_secret=self.key_secret, - ) - assert token_request.ttl == token_params['ttl'] - assert token_request.capability == str(token_params['capability']) - assert token_request.client_id == token_params['client_id'] - assert token_request.timestamp == token_params['timestamp'] - assert token_request.nonce == token_params['nonce'] - - def test_capability(self): - capability = Capability({'channel': ['publish']}) - token_request = self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret, - token_params={'capability': capability}) - assert token_request.capability == str(capability) - - def auth_callback(token_params): - return token_request - - ably = TestApp.get_ably_rest(key=None, auth_callback=auth_callback, - use_binary_protocol=self.use_binary_protocol) - - token = ably.auth.authorize() - - assert str(token.capability) == str(capability) - ably.close() - - @dont_vary_protocol - def test_hmac(self): - ably = AblyRestSync(key_name='a_key_name', key_secret='a_secret') - token_params = { - 'ttl': 1000, - 'nonce': 'abcde100', - 'client_id': 'a_id', - 'timestamp': 1000, - } - token_request = ably.auth.create_token_request( - token_params, key_secret='a_secret', key_name='a_key_name') - assert token_request.mac == 'sYkCH0Un+WgzI7/Nhy0BoQIKq9HmjKynCRs4E3qAbGQ=' - ably.close() - - # AO2g - @dont_vary_protocol - def test_query_server_time(self): - with patch('ably.sync.rest.rest.AblyRestSync.time', wraps=self.ably.time) as server_time: - self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret, query_time=True) - assert server_time.call_count == 1 - - self.ably.auth.create_token_request( - key_name=self.key_name, key_secret=self.key_secret, query_time=False) - assert server_time.call_count == 1 diff --git a/test/ably/sync/testapp.py b/test/ably/sync/testapp.py deleted file mode 100644 index 0947296f..00000000 --- a/test/ably/sync/testapp.py +++ /dev/null @@ -1,115 +0,0 @@ -import json -import os -import logging - -from ably.sync.rest.rest import AblyRestSync -from ably.sync.types.capability import Capability -from ably.sync.types.options import Options -from ably.sync.util.exceptions import AblyException -from ably.sync.realtime.realtime import AblyRealtime - -log = logging.getLogger(__name__) - -with open(os.path.dirname(__file__) + '/../../assets/testAppSpec.json', 'r') as f: - app_spec_local = json.loads(f.read()) - -tls = (os.environ.get('ABLY_TLS') or "true").lower() == "true" -rest_host = os.environ.get('ABLY_REST_HOST', 'sandbox-rest.ably.io') -realtime_host = os.environ.get('ABLY_REALTIME_HOST', 'sandbox-realtime.ably.io') - -environment = os.environ.get('ABLY_ENV', 'sandbox') - -port = 80 -tls_port = 443 - -if rest_host and not rest_host.endswith("rest.ably.io"): - tls = tls and rest_host != "localhost" - port = 8080 - tls_port = 8081 - - -ably = AblyRestSync(token='not_a_real_token', - port=port, tls_port=tls_port, tls=tls, - environment=environment, - use_binary_protocol=False) - - -class TestApp: - __test_vars = None - - @staticmethod - def get_test_vars(): - if not TestApp.__test_vars: - r = ably.http.post("/apps", body=app_spec_local, skip_auth=True) - AblyException.raise_for_response(r) - - app_spec = r.json() - - app_id = app_spec.get("appId", "") - - test_vars = { - "app_id": app_id, - "host": rest_host, - "port": port, - "tls_port": tls_port, - "tls": tls, - "environment": environment, - "realtime_host": realtime_host, - "keys": [{ - "key_name": "%s.%s" % (app_id, k.get("id", "")), - "key_secret": k.get("value", ""), - "key_str": "%s.%s:%s" % (app_id, k.get("id", ""), k.get("value", "")), - "capability": Capability(json.loads(k.get("capability", "{}"))), - } for k in app_spec.get("keys", [])] - } - - TestApp.__test_vars = test_vars - log.debug([(app_id, k.get("id", ""), k.get("value", "")) - for k in app_spec.get("keys", [])]) - - return TestApp.__test_vars - - @staticmethod - def get_ably_rest(**kw): - test_vars = TestApp.get_test_vars() - options = TestApp.get_options(test_vars, **kw) - options.update(kw) - return AblyRestSync(**options) - - @staticmethod - def get_ably_realtime(**kw): - test_vars = TestApp.get_test_vars() - options = TestApp.get_options(test_vars, **kw) - return AblyRealtime(**options) - - @staticmethod - def get_options(test_vars, **kwargs): - options = { - 'port': test_vars["port"], - 'tls_port': test_vars["tls_port"], - 'tls': test_vars["tls"], - 'environment': test_vars["environment"], - } - auth_methods = ["auth_url", "auth_callback", "token", "token_details", "key"] - if not any(x in kwargs for x in auth_methods): - options["key"] = test_vars["keys"][0]["key_str"] - - if any(x in kwargs for x in ["rest_host", "realtime_host"]): - options["environment"] = None - - options.update(kwargs) - - return options - - @staticmethod - def clear_test_vars(): - test_vars = TestApp.__test_vars - options = Options(key=test_vars["keys"][0]["key_str"]) - options.rest_host = test_vars["host"] - options.port = test_vars["port"] - options.tls_port = test_vars["tls_port"] - options.tls = test_vars["tls"] - ably = TestApp.get_ably_rest() - ably.http.delete('/apps/' + test_vars['app_id']) - TestApp.__test_vars = None - ably.close() diff --git a/test/ably/sync/utils.py b/test/ably/sync/utils.py deleted file mode 100644 index a45a7b39..00000000 --- a/test/ably/sync/utils.py +++ /dev/null @@ -1,180 +0,0 @@ -import functools -import os -import random -import string -import unittest -import sys - -if sys.version_info >= (3, 8): - from unittest import IsolatedAsyncioTestCase -else: - from async_case import IsolatedAsyncioTestCase - -import msgpack -import mock -import respx -from httpx import Response - -from ably.sync.http.http import HttpSync - - -class BaseTestCase(unittest.TestCase): - - def respx_add_empty_msg_pack(self, url, method='GET'): - respx.route(method=method, url=url).return_value = Response( - status_code=200, - headers={'content-type': 'application/x-msgpack'}, - content=msgpack.packb({}) - ) - - @classmethod - def get_channel_name(cls, prefix=''): - return prefix + random_string(10) - - @classmethod - def get_channel(cls, prefix=''): - name = cls.get_channel_name(prefix) - return cls.ably.channels.get(name) - - -class BaseAsyncTestCase(IsolatedAsyncioTestCase): - - def respx_add_empty_msg_pack(self, url, method='GET'): - respx.route(method=method, url=url).return_value = Response( - status_code=200, - headers={'content-type': 'application/x-msgpack'}, - content=msgpack.packb({}) - ) - - @classmethod - def get_channel_name(cls, prefix=''): - return prefix + random_string(10) - - def get_channel(self, prefix=''): - name = self.get_channel_name(prefix) - return self.ably.channels.get(name) - - -def assert_responses_type(protocol): - """ - This is a decorator to check if we retrieved responses with the correct protocol. - usage: - - @assert_responses_type('json') - def test_something(self): - ... - - this will check if all responses received during the test will be in the format - json. - supports json and msgpack - """ - responses = [] - - def patch(): - original = HttpSync.make_request - - def fake_make_request(self, *args, **kwargs): - response = original(self, *args, **kwargs) - responses.append(response) - return response - - patcher = mock.patch.object(HttpSync, 'make_request', fake_make_request) - patcher.start() - return patcher - - def unpatch(patcher): - patcher.stop() - - def test_decorator(fn): - @functools.wraps(fn) - def test_decorated(self, *args, **kwargs): - patcher = patch() - fn(self, *args, **kwargs) - unpatch(patcher) - - assert len(responses) >= 1, \ - "If your test doesn't make any requests, use the @dont_vary_protocol decorator" - - for response in responses: - # In HTTP/2 some header fields are optional in case of 204 status code - if protocol == 'json': - if response.status_code != 204: - assert response.headers['content-type'] == 'application/json' - if response.content: - response.json() - else: - if response.status_code != 204: - assert response.headers['content-type'] == 'application/x-msgpack' - if response.content: - msgpack.unpackb(response.content) - - return test_decorated - - return test_decorator - - -class VaryByProtocolTestsMetaclass(type): - """ - Metaclass to run tests in more than one protocol. - Usage: - * set this as metaclass of the TestCase class - * create the following method: - def per_protocol_setup(self, use_binary_protocol): - # do something here that will run before each test. - * now every test will run twice and before test is run per_protocol_setup - is called - * exclude tests with the @dont_vary_protocol decorator - """ - - def __new__(cls, clsname, bases, dct): - for key, value in tuple(dct.items()): - if key.startswith('test') and not getattr(value, 'dont_vary_protocol', - False): - wrapper_bin = cls.wrap_as('bin', key, value) - wrapper_text = cls.wrap_as('text', key, value) - - dct[key + '_bin'] = wrapper_bin - dct[key + '_text'] = wrapper_text - del dct[key] - - return super().__new__(cls, clsname, bases, dct) - - @staticmethod - def wrap_as(ttype, old_name, old_func): - expected_content = {'bin': 'msgpack', 'text': 'json'} - - @assert_responses_type(expected_content[ttype]) - def wrapper(self): - if hasattr(self, 'per_protocol_setup'): - self.per_protocol_setup(ttype == 'bin') - old_func(self) - - wrapper.__name__ = old_name + '_' + ttype - return wrapper - - -def dont_vary_protocol(func): - func.dont_vary_protocol = True - return func - - -def random_string(length, alphabet=string.ascii_letters): - return ''.join([random.choice(alphabet) for x in range(length)]) - - -def new_dict(src, **kw): - new = src.copy() - new.update(kw) - return new - - -def get_random_key(d): - return random.choice(list(d)) - - -def get_submodule_dir(filepath): - root_dir = os.path.dirname(filepath) - while True: - if os.path.exists(os.path.join(root_dir, 'submodules')): - return os.path.join(root_dir, 'submodules') - root_dir = os.path.dirname(root_dir) From 01aefca255b3dd06b66f14a3795af467571730d3 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 17:08:35 +0530 Subject: [PATCH 39/52] uncommented restcrypto test file --- test/ably/rest/restcrypto_test.py | 528 +++++++++++++++--------------- 1 file changed, 264 insertions(+), 264 deletions(-) diff --git a/test/ably/rest/restcrypto_test.py b/test/ably/rest/restcrypto_test.py index 3dd89bc2..18bf69ac 100644 --- a/test/ably/rest/restcrypto_test.py +++ b/test/ably/rest/restcrypto_test.py @@ -1,264 +1,264 @@ -# import json -# import os -# import logging -# import base64 -# -# import pytest -# -# from ably import AblyException -# from ably.types.message import Message -# from ably.util.crypto import CipherParams, get_cipher, generate_random_key, get_default_params -# -# from Crypto import Random -# -# from test.ably.testapp import TestApp -# from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase -# -# log = logging.getLogger(__name__) -# -# -# class TestRestCrypto(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): -# -# async def asyncSetUp(self): -# self.test_vars = await TestApp.get_test_vars() -# self.ably = await TestApp.get_ably_rest() -# self.ably2 = await TestApp.get_ably_rest() -# -# async def asyncTearDown(self): -# await self.ably.close() -# await self.ably2.close() -# -# def per_protocol_setup(self, use_binary_protocol): -# # This will be called every test that vary by protocol for each protocol -# self.ably.options.use_binary_protocol = use_binary_protocol -# self.ably2.options.use_binary_protocol = use_binary_protocol -# self.use_binary_protocol = use_binary_protocol -# -# @dont_vary_protocol -# def test_cbc_channel_cipher(self): -# key = ( -# b'\x93\xe3\x5c\xc9\x77\x53\xfd\x1a' -# b'\x79\xb4\xd8\x84\xe7\xdc\xfd\xdf') -# -# iv = ( -# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' -# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0') -# -# log.debug("KEY_LEN: %d" % len(key)) -# log.debug("IV_LEN: %d" % len(iv)) -# cipher = get_cipher({'key': key, 'iv': iv}) -# -# plaintext = b"The quick brown fox" -# expected_ciphertext = ( -# b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' -# b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0' -# b'\x83\x5c\xcf\xce\x0c\xfd\xbe\x37' -# b'\xb7\x92\x12\x04\x1d\x45\x68\xa4' -# b'\xdf\x7f\x6e\x38\x17\x4a\xff\x50' -# b'\x73\x23\xbb\xca\x16\xb0\xe2\x84') -# -# actual_ciphertext = cipher.encrypt(plaintext) -# -# assert expected_ciphertext == actual_ciphertext -# -# async def test_crypto_publish(self): -# channel_name = self.get_channel_name('persisted:crypto_publish_text') -# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# history = await publish0.history() -# messages = history.items -# assert messages is not None, "Expected non-None messages" -# assert 4 == len(messages), "Expected 4 messages" -# -# message_contents = dict((m.name, m.data) for m in messages) -# log.debug("message_contents: %s" % str(message_contents)) -# -# assert "This is a string message payload" == message_contents["publish3"],\ -# "Expect publish3 to be expected String)" -# -# assert b"This is a byte[] message payload" == message_contents["publish4"],\ -# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) -# -# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ -# "Expect publish5 to be expected JSONObject" -# -# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ -# "Expect publish6 to be expected JSONObject" -# -# async def test_crypto_publish_256(self): -# rndfile = Random.new() -# key = rndfile.read(32) -# channel_name = 'persisted:crypto_publish_text_256' -# channel_name += '_bin' if self.use_binary_protocol else '_text' -# -# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# history = await publish0.history() -# messages = history.items -# assert messages is not None, "Expected non-None messages" -# assert 4 == len(messages), "Expected 4 messages" -# -# message_contents = dict((m.name, m.data) for m in messages) -# log.debug("message_contents: %s" % str(message_contents)) -# -# assert "This is a string message payload" == message_contents["publish3"],\ -# "Expect publish3 to be expected String)" -# -# assert b"This is a byte[] message payload" == message_contents["publish4"],\ -# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) -# -# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ -# "Expect publish5 to be expected JSONObject" -# -# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ -# "Expect publish6 to be expected JSONObject" -# -# async def test_crypto_publish_key_mismatch(self): -# channel_name = self.get_channel_name('persisted:crypto_publish_key_mismatch') -# -# publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# with pytest.raises(AblyException) as excinfo: -# await rx_channel.history() -# -# message = excinfo.value.message -# assert 'invalid-padding' == message or "codec can't decode" in message -# -# async def test_crypto_send_unencrypted(self): -# channel_name = self.get_channel_name('persisted:crypto_send_unencrypted') -# publish0 = self.ably.channels[channel_name] -# -# await publish0.publish("publish3", "This is a string message payload") -# await publish0.publish("publish4", b"This is a byte[] message payload") -# await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) -# await publish0.publish("publish6", ["This is a JSONArray message payload"]) -# -# rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) -# -# history = await rx_channel.history() -# messages = history.items -# assert messages is not None, "Expected non-None messages" -# assert 4 == len(messages), "Expected 4 messages" -# -# message_contents = dict((m.name, m.data) for m in messages) -# log.debug("message_contents: %s" % str(message_contents)) -# -# assert "This is a string message payload" == message_contents["publish3"],\ -# "Expect publish3 to be expected String" -# -# assert b"This is a byte[] message payload" == message_contents["publish4"],\ -# "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) -# -# assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ -# "Expect publish5 to be expected JSONObject" -# -# assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ -# "Expect publish6 to be expected JSONObject" -# -# async def test_crypto_encrypted_unhandled(self): -# channel_name = self.get_channel_name('persisted:crypto_send_encrypted_unhandled') -# key = b'0123456789abcdef' -# data = 'foobar' -# publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) -# -# await publish0.publish("publish0", data) -# -# rx_channel = self.ably2.channels[channel_name] -# history = await rx_channel.history() -# message = history.items[0] -# cipher = get_cipher(get_default_params({'key': key})) -# assert cipher.decrypt(message.data).decode() == data -# assert message.encoding == 'utf-8/cipher+aes-128-cbc' -# -# @dont_vary_protocol -# def test_cipher_params(self): -# params = CipherParams(secret_key='0123456789abcdef') -# assert params.algorithm == 'AES' -# assert params.mode == 'CBC' -# assert params.key_length == 128 -# -# params = CipherParams(secret_key='0123456789abcdef' * 2) -# assert params.algorithm == 'AES' -# assert params.mode == 'CBC' -# assert params.key_length == 256 -# -# -# class AbstractTestCryptoWithFixture: -# -# @classmethod -# def setUpClass(cls): -# resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file -# with open(resources_path, 'r') as f: -# cls.fixture = json.loads(f.read()) -# cls.params = { -# 'secret_key': base64.b64decode(cls.fixture['key'].encode('ascii')), -# 'mode': cls.fixture['mode'], -# 'algorithm': cls.fixture['algorithm'], -# 'iv': base64.b64decode(cls.fixture['iv'].encode('ascii')), -# } -# cls.cipher_params = CipherParams(**cls.params) -# cls.cipher = get_cipher(cls.cipher_params) -# cls.items = cls.fixture['items'] -# -# def get_encoded(self, encoded_item): -# if encoded_item.get('encoding') == 'base64': -# return base64.b64decode(encoded_item['data'].encode('ascii')) -# elif encoded_item.get('encoding') == 'json': -# return json.loads(encoded_item['data']) -# return encoded_item['data'] -# -# # TM3 -# def test_decode(self): -# for item in self.items: -# assert item['encoded']['name'] == item['encrypted']['name'] -# message = Message.from_encoded(item['encrypted'], self.cipher) -# assert message.encoding == '' -# expected_data = self.get_encoded(item['encoded']) -# assert expected_data == message.data -# -# # TM3 -# def test_decode_array(self): -# items_encrypted = [item['encrypted'] for item in self.items] -# messages = Message.from_encoded_array(items_encrypted, self.cipher) -# for i, message in enumerate(messages): -# assert message.encoding == '' -# expected_data = self.get_encoded(self.items[i]['encoded']) -# assert expected_data == message.data -# -# def test_encode(self): -# for item in self.items: -# # need to reset iv -# self.cipher_params = CipherParams(**self.params) -# self.cipher = get_cipher(self.cipher_params) -# data = self.get_encoded(item['encoded']) -# expected = item['encrypted'] -# message = Message(item['encoded']['name'], data) -# message.encrypt(self.cipher) -# as_dict = message.as_dict() -# assert as_dict['data'] == expected['data'] -# assert as_dict['encoding'] == expected['encoding'] -# -# -# class TestCryptoWithFixture128(AbstractTestCryptoWithFixture, BaseTestCase): -# fixture_file = 'crypto-data-128.json' -# -# -# class TestCryptoWithFixture256(AbstractTestCryptoWithFixture, BaseTestCase): -# fixture_file = 'crypto-data-256.json' +import json +import os +import logging +import base64 + +import pytest + +from ably import AblyException +from ably.types.message import Message +from ably.util.crypto import CipherParams, get_cipher, generate_random_key, get_default_params + +from Crypto import Random + +from test.ably.testapp import TestApp +from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase + +log = logging.getLogger(__name__) + + +class TestRestCrypto(BaseAsyncTestCase, metaclass=VaryByProtocolTestsMetaclass): + + async def asyncSetUp(self): + self.test_vars = await TestApp.get_test_vars() + self.ably = await TestApp.get_ably_rest() + self.ably2 = await TestApp.get_ably_rest() + + async def asyncTearDown(self): + await self.ably.close() + await self.ably2.close() + + def per_protocol_setup(self, use_binary_protocol): + # This will be called every test that vary by protocol for each protocol + self.ably.options.use_binary_protocol = use_binary_protocol + self.ably2.options.use_binary_protocol = use_binary_protocol + self.use_binary_protocol = use_binary_protocol + + @dont_vary_protocol + def test_cbc_channel_cipher(self): + key = ( + b'\x93\xe3\x5c\xc9\x77\x53\xfd\x1a' + b'\x79\xb4\xd8\x84\xe7\xdc\xfd\xdf') + + iv = ( + b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' + b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0') + + log.debug("KEY_LEN: %d" % len(key)) + log.debug("IV_LEN: %d" % len(iv)) + cipher = get_cipher({'key': key, 'iv': iv}) + + plaintext = b"The quick brown fox" + expected_ciphertext = ( + b'\x28\x4c\xe4\x8d\x4b\xdc\x9d\x42' + b'\x8a\x77\x6b\x53\x2d\xc7\xb5\xc0' + b'\x83\x5c\xcf\xce\x0c\xfd\xbe\x37' + b'\xb7\x92\x12\x04\x1d\x45\x68\xa4' + b'\xdf\x7f\x6e\x38\x17\x4a\xff\x50' + b'\x73\x23\xbb\xca\x16\xb0\xe2\x84') + + actual_ciphertext = cipher.encrypt(plaintext) + + assert expected_ciphertext == actual_ciphertext + + async def test_crypto_publish(self): + channel_name = self.get_channel_name('persisted:crypto_publish_text') + publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) + + await publish0.publish("publish3", "This is a string message payload") + await publish0.publish("publish4", b"This is a byte[] message payload") + await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) + await publish0.publish("publish6", ["This is a JSONArray message payload"]) + + history = await publish0.history() + messages = history.items + assert messages is not None, "Expected non-None messages" + assert 4 == len(messages), "Expected 4 messages" + + message_contents = dict((m.name, m.data) for m in messages) + log.debug("message_contents: %s" % str(message_contents)) + + assert "This is a string message payload" == message_contents["publish3"],\ + "Expect publish3 to be expected String)" + + assert b"This is a byte[] message payload" == message_contents["publish4"],\ + "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) + + assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ + "Expect publish5 to be expected JSONObject" + + assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ + "Expect publish6 to be expected JSONObject" + + async def test_crypto_publish_256(self): + rndfile = Random.new() + key = rndfile.read(32) + channel_name = 'persisted:crypto_publish_text_256' + channel_name += '_bin' if self.use_binary_protocol else '_text' + + publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) + + await publish0.publish("publish3", "This is a string message payload") + await publish0.publish("publish4", b"This is a byte[] message payload") + await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) + await publish0.publish("publish6", ["This is a JSONArray message payload"]) + + history = await publish0.history() + messages = history.items + assert messages is not None, "Expected non-None messages" + assert 4 == len(messages), "Expected 4 messages" + + message_contents = dict((m.name, m.data) for m in messages) + log.debug("message_contents: %s" % str(message_contents)) + + assert "This is a string message payload" == message_contents["publish3"],\ + "Expect publish3 to be expected String)" + + assert b"This is a byte[] message payload" == message_contents["publish4"],\ + "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) + + assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ + "Expect publish5 to be expected JSONObject" + + assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ + "Expect publish6 to be expected JSONObject" + + async def test_crypto_publish_key_mismatch(self): + channel_name = self.get_channel_name('persisted:crypto_publish_key_mismatch') + + publish0 = self.ably.channels.get(channel_name, cipher={'key': generate_random_key()}) + + await publish0.publish("publish3", "This is a string message payload") + await publish0.publish("publish4", b"This is a byte[] message payload") + await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) + await publish0.publish("publish6", ["This is a JSONArray message payload"]) + + rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) + + with pytest.raises(AblyException) as excinfo: + await rx_channel.history() + + message = excinfo.value.message + assert 'invalid-padding' == message or "codec can't decode" in message + + async def test_crypto_send_unencrypted(self): + channel_name = self.get_channel_name('persisted:crypto_send_unencrypted') + publish0 = self.ably.channels[channel_name] + + await publish0.publish("publish3", "This is a string message payload") + await publish0.publish("publish4", b"This is a byte[] message payload") + await publish0.publish("publish5", {"test": "This is a JSONObject message payload"}) + await publish0.publish("publish6", ["This is a JSONArray message payload"]) + + rx_channel = self.ably2.channels.get(channel_name, cipher={'key': generate_random_key()}) + + history = await rx_channel.history() + messages = history.items + assert messages is not None, "Expected non-None messages" + assert 4 == len(messages), "Expected 4 messages" + + message_contents = dict((m.name, m.data) for m in messages) + log.debug("message_contents: %s" % str(message_contents)) + + assert "This is a string message payload" == message_contents["publish3"],\ + "Expect publish3 to be expected String" + + assert b"This is a byte[] message payload" == message_contents["publish4"],\ + "Expect publish4 to be expected byte[]. Actual: %s" % str(message_contents['publish4']) + + assert {"test": "This is a JSONObject message payload"} == message_contents["publish5"],\ + "Expect publish5 to be expected JSONObject" + + assert ["This is a JSONArray message payload"] == message_contents["publish6"],\ + "Expect publish6 to be expected JSONObject" + + async def test_crypto_encrypted_unhandled(self): + channel_name = self.get_channel_name('persisted:crypto_send_encrypted_unhandled') + key = b'0123456789abcdef' + data = 'foobar' + publish0 = self.ably.channels.get(channel_name, cipher={'key': key}) + + await publish0.publish("publish0", data) + + rx_channel = self.ably2.channels[channel_name] + history = await rx_channel.history() + message = history.items[0] + cipher = get_cipher(get_default_params({'key': key})) + assert cipher.decrypt(message.data).decode() == data + assert message.encoding == 'utf-8/cipher+aes-128-cbc' + + @dont_vary_protocol + def test_cipher_params(self): + params = CipherParams(secret_key='0123456789abcdef') + assert params.algorithm == 'AES' + assert params.mode == 'CBC' + assert params.key_length == 128 + + params = CipherParams(secret_key='0123456789abcdef' * 2) + assert params.algorithm == 'AES' + assert params.mode == 'CBC' + assert params.key_length == 256 + + +class AbstractTestCryptoWithFixture: + + @classmethod + def setUpClass(cls): + resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file + with open(resources_path, 'r') as f: + cls.fixture = json.loads(f.read()) + cls.params = { + 'secret_key': base64.b64decode(cls.fixture['key'].encode('ascii')), + 'mode': cls.fixture['mode'], + 'algorithm': cls.fixture['algorithm'], + 'iv': base64.b64decode(cls.fixture['iv'].encode('ascii')), + } + cls.cipher_params = CipherParams(**cls.params) + cls.cipher = get_cipher(cls.cipher_params) + cls.items = cls.fixture['items'] + + def get_encoded(self, encoded_item): + if encoded_item.get('encoding') == 'base64': + return base64.b64decode(encoded_item['data'].encode('ascii')) + elif encoded_item.get('encoding') == 'json': + return json.loads(encoded_item['data']) + return encoded_item['data'] + + # TM3 + def test_decode(self): + for item in self.items: + assert item['encoded']['name'] == item['encrypted']['name'] + message = Message.from_encoded(item['encrypted'], self.cipher) + assert message.encoding == '' + expected_data = self.get_encoded(item['encoded']) + assert expected_data == message.data + + # TM3 + def test_decode_array(self): + items_encrypted = [item['encrypted'] for item in self.items] + messages = Message.from_encoded_array(items_encrypted, self.cipher) + for i, message in enumerate(messages): + assert message.encoding == '' + expected_data = self.get_encoded(self.items[i]['encoded']) + assert expected_data == message.data + + def test_encode(self): + for item in self.items: + # need to reset iv + self.cipher_params = CipherParams(**self.params) + self.cipher = get_cipher(self.cipher_params) + data = self.get_encoded(item['encoded']) + expected = item['encrypted'] + message = Message(item['encoded']['name'], data) + message.encrypt(self.cipher) + as_dict = message.as_dict() + assert as_dict['data'] == expected['data'] + assert as_dict['encoding'] == expected['encoding'] + + +class TestCryptoWithFixture128(AbstractTestCryptoWithFixture, BaseTestCase): + fixture_file = 'crypto-data-128.json' + + +class TestCryptoWithFixture256(AbstractTestCryptoWithFixture, BaseTestCase): + fixture_file = 'crypto-data-256.json' From 599cc2aabc4efabaf0a116abe5a51b152f63f6b6 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 17:10:34 +0530 Subject: [PATCH 40/52] Removed uncessary type signature from unasync generator --- unasync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unasync.py b/unasync.py index 7958682e..b644ee23 100644 --- a/unasync.py +++ b/unasync.py @@ -218,7 +218,7 @@ def unasync_files(fpath_list, rules): found_rule._unasync_file(f) -def find_files(dir_path, file_name_regex) -> list[str]: +def find_files(dir_path, file_name_regex): return glob.glob(os.path.join(dir_path, "**", file_name_regex), recursive=True) From 16e86d9d40860715c4bab578f5f15e41bf11ce83 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 17:22:52 +0530 Subject: [PATCH 41/52] Fixed crypto test for robust submodules path --- test/ably/rest/restcrypto_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/ably/rest/restcrypto_test.py b/test/ably/rest/restcrypto_test.py index 18bf69ac..b6ea577b 100644 --- a/test/ably/rest/restcrypto_test.py +++ b/test/ably/rest/restcrypto_test.py @@ -11,6 +11,7 @@ from Crypto import Random +from test.ably import utils from test.ably.testapp import TestApp from test.ably.utils import dont_vary_protocol, VaryByProtocolTestsMetaclass, BaseTestCase, BaseAsyncTestCase @@ -204,7 +205,7 @@ class AbstractTestCryptoWithFixture: @classmethod def setUpClass(cls): - resources_path = os.path.dirname(__file__) + '/../../../submodules/test-resources/%s' % cls.fixture_file + resources_path = os.path.join(utils.get_submodule_dir(__file__), 'test-resources', cls.fixture_file) with open(resources_path, 'r') as f: cls.fixture = json.loads(f.read()) cls.params = { From 27306487fab307676265ff6a488b5009466ad85c Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Fri, 6 Oct 2023 17:23:11 +0530 Subject: [PATCH 42/52] updated readme for new sync api --- UPDATING.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/UPDATING.md b/UPDATING.md index b30a7f94..c655b5b9 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -72,6 +72,7 @@ These include: - Deprecation of support for Python versions 3.4, 3.5 and 3.6 - New, asynchronous API + - Deprecated synchronous API ### Deprecation of Python 3.4, 3.5 and 3.6 @@ -85,6 +86,26 @@ To see which versions of Python we test the SDK against, please look at our The 1.2.0 version introduces a breaking change, which changes the way of interacting with the SDK from synchronous to asynchronous, using [the `asyncio` foundational library](https://docs.python.org/3.7/library/asyncio.html) to provide support for `async`/`await` syntax. Because of this breaking change, every call that interacts with the Ably REST API must be refactored to this asynchronous way. +Important Update: +- If you want to keep using old synchronous style API, import `AblyRestSync` client instead. +- This is applicable only for Ably REST APIs. + +```python +from ably.sync import AblyRestSync + +def main(): + ably = AblyRestSync('api:key', sync_enabled=True) + channel = ably.channels.get("channel_name") + channel.publish('event', 'message') + +if __name__ == "__main__": + main() +``` +- To use old `AblyRest` class, but with `sync` style API. Import it as, +```python +from ably.sync import AblyRestSync as AblyRest +``` + #### Publishing Messages This old style, synchronous example: @@ -253,4 +274,4 @@ Must now be replaced with this new style, asynchronous form: ```python await client.time() await client.close() -``` +``` \ No newline at end of file From 8218a707a81593fcc2aaf3f6abec3eb18f1b732b Mon Sep 17 00:00:00 2001 From: sachin shinde Date: Mon, 9 Oct 2023 18:55:40 +0530 Subject: [PATCH 43/52] Apply suggestions from code review Co-authored-by: Owen Pearson <48608556+owenpearson@users.noreply.github.com> --- UPDATING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/UPDATING.md b/UPDATING.md index c655b5b9..cddda023 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -94,7 +94,7 @@ Important Update: from ably.sync import AblyRestSync def main(): - ably = AblyRestSync('api:key', sync_enabled=True) + ably = AblyRestSync('api:key') channel = ably.channels.get("channel_name") channel.publish('event', 'message') From 3c057da666484a8f9407c57c22d7ebd41d55618e Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Mon, 9 Oct 2023 19:04:00 +0530 Subject: [PATCH 44/52] Added idea and ably sync packages to gitignore file --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0d07b9f2..90697255 100644 --- a/.gitignore +++ b/.gitignore @@ -54,4 +54,5 @@ app_spec.pkl ably/types/options.py.orig test/ably/restsetup.py.orig -.idea/**/* \ No newline at end of file +.idea/**/* +**/ably/sync/*** From ba6f952069443d27bfc8f57e1966ec038a8587b3 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Mon, 9 Oct 2023 19:26:47 +0530 Subject: [PATCH 45/52] Refactored classes to be renamed in the list of rename_classes --- unasync.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/unasync.py b/unasync.py index b644ee23..b33f7274 100644 --- a/unasync.py +++ b/unasync.py @@ -229,15 +229,21 @@ def find_files(dir_path, file_name_regex): _IMPORTS_REPLACE["ably"] = "ably.sync" -_CLASS_RENAME["AblyRest"] = "AblyRestSync" -_CLASS_RENAME["Push"] = "PushSync" -_CLASS_RENAME["PushAdmin"] = "PushAdminSync" -_CLASS_RENAME["Channel"] = "ChannelSync" -_CLASS_RENAME["Channels"] = "ChannelsSync" -_CLASS_RENAME["Auth"] = "AuthSync" -_CLASS_RENAME["Http"] = "HttpSync" -_CLASS_RENAME["PaginatedResult"] = "PaginatedResultSync" -_CLASS_RENAME["HttpPaginatedResponse"] = "HttpPaginatedResponseSync" +rename_classes = [ + "AblyRest", + "Push", + "PushAdmin", + "Channel", + "Channels", + "Auth", + "Http", + "PaginatedResult", + "HttpPaginatedResponse" +] + +# here... +for class_name in rename_classes: + _CLASS_RENAME[class_name] = f"{class_name}Sync" _STRING_REPLACE["Auth"] = "AuthSync" From a4e510520519a61c84295af8db549d4ba9476048 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Mon, 9 Oct 2023 19:38:06 +0530 Subject: [PATCH 46/52] Moved unasync script under scripts directory, updated pyproject.toml --- .github/workflows/check.yml | 2 +- ably/scripts/__init__.py | 0 unasync.py => ably/scripts/unasync.py | 103 +++++++++++++------------- pyproject.toml | 3 + 4 files changed, 56 insertions(+), 52 deletions(-) create mode 100644 ably/scripts/__init__.py rename unasync.py => ably/scripts/unasync.py (76%) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index ddf6a644..4b70e335 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -36,6 +36,6 @@ jobs: - name: Lint with flake8 run: poetry run flake8 - name: Generate rest sync code and tests - run: poetry run python unasync.py + run: poetry run unasync - name: Test with pytest run: poetry run pytest diff --git a/ably/scripts/__init__.py b/ably/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unasync.py b/ably/scripts/unasync.py similarity index 76% rename from unasync.py rename to ably/scripts/unasync.py index b33f7274..c4c8e57f 100644 --- a/unasync.py +++ b/ably/scripts/unasync.py @@ -222,72 +222,73 @@ def find_files(dir_path, file_name_regex): return glob.glob(os.path.join(dir_path, "**", file_name_regex), recursive=True) -# Source files ========================================== +def run(): + # Source files ========================================== -_TOKEN_REPLACE["AsyncClient"] = "Client" -_TOKEN_REPLACE["aclose"] = "close" + _TOKEN_REPLACE["AsyncClient"] = "Client" + _TOKEN_REPLACE["aclose"] = "close" -_IMPORTS_REPLACE["ably"] = "ably.sync" + _IMPORTS_REPLACE["ably"] = "ably.sync" -rename_classes = [ - "AblyRest", - "Push", - "PushAdmin", - "Channel", - "Channels", - "Auth", - "Http", - "PaginatedResult", - "HttpPaginatedResponse" -] + rename_classes = [ + "AblyRest", + "Push", + "PushAdmin", + "Channel", + "Channels", + "Auth", + "Http", + "PaginatedResult", + "HttpPaginatedResponse" + ] -# here... -for class_name in rename_classes: - _CLASS_RENAME[class_name] = f"{class_name}Sync" + # here... + for class_name in rename_classes: + _CLASS_RENAME[class_name] = f"{class_name}Sync" -_STRING_REPLACE["Auth"] = "AuthSync" + _STRING_REPLACE["Auth"] = "AuthSync" -src_dir_path = os.path.join(os.getcwd(), "ably") -dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") + src_dir_path = os.path.join(os.getcwd(), "ably") + dest_dir_path = os.path.join(os.getcwd(), "ably", "sync") -relevant_src_files = (set(find_files(src_dir_path, "*.py")) - - set(find_files(dest_dir_path, "*.py"))) + relevant_src_files = (set(find_files(src_dir_path, "*.py")) - + set(find_files(dest_dir_path, "*.py"))) -unasync_files(list(relevant_src_files), [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) + unasync_files(list(relevant_src_files), [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) -# Test files ============================================== + # Test files ============================================== -_TOKEN_REPLACE["asyncSetUp"] = "setUp" -_TOKEN_REPLACE["asyncTearDown"] = "tearDown" -_TOKEN_REPLACE["AsyncMock"] = "Mock" + _TOKEN_REPLACE["asyncSetUp"] = "setUp" + _TOKEN_REPLACE["asyncTearDown"] = "tearDown" + _TOKEN_REPLACE["AsyncMock"] = "Mock" -_TOKEN_REPLACE["_Channel__publish_request_body"] = "_ChannelSync__publish_request_body" -_TOKEN_REPLACE["_Http__client"] = "_HttpSync__client" + _TOKEN_REPLACE["_Channel__publish_request_body"] = "_ChannelSync__publish_request_body" + _TOKEN_REPLACE["_Http__client"] = "_HttpSync__client" -_IMPORTS_REPLACE["test.ably"] = "test.ably.sync" + _IMPORTS_REPLACE["test.ably"] = "test.ably.sync" -_STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' -_STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.AuthSync.request_token' -_STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' -_STRING_REPLACE['ably.rest.rest.Http.post'] = 'ably.sync.rest.rest.HttpSync.post' -_STRING_REPLACE['httpx.AsyncClient.send'] = 'httpx.Client.send' -_STRING_REPLACE['ably.util.exceptions.AblyException.raise_for_response'] = \ - 'ably.sync.util.exceptions.AblyException.raise_for_response' -_STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRestSync.time' -_STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.AuthSync._timestamp' + _STRING_REPLACE['/../assets/testAppSpec.json'] = '/../../assets/testAppSpec.json' + _STRING_REPLACE['ably.rest.auth.Auth.request_token'] = 'ably.sync.rest.auth.AuthSync.request_token' + _STRING_REPLACE['ably.rest.auth.TokenRequest'] = 'ably.sync.rest.auth.TokenRequest' + _STRING_REPLACE['ably.rest.rest.Http.post'] = 'ably.sync.rest.rest.HttpSync.post' + _STRING_REPLACE['httpx.AsyncClient.send'] = 'httpx.Client.send' + _STRING_REPLACE['ably.util.exceptions.AblyException.raise_for_response'] = \ + 'ably.sync.util.exceptions.AblyException.raise_for_response' + _STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRestSync.time' + _STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.AuthSync._timestamp' -# round 1 -src_dir_path = os.path.join(os.getcwd(), "test", "ably") -dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") -src_files = [os.path.join(os.getcwd(), "test", "ably", "testapp.py"), - os.path.join(os.getcwd(), "test", "ably", "utils.py")] + # round 1 + src_dir_path = os.path.join(os.getcwd(), "test", "ably") + dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") + src_files = [os.path.join(os.getcwd(), "test", "ably", "testapp.py"), + os.path.join(os.getcwd(), "test", "ably", "utils.py")] -unasync_files(src_files, [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) + unasync_files(src_files, [Rule(fromdir=src_dir_path, todir=dest_dir_path)]) -# round 2 -src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") -dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") -src_files = find_files(src_dir_path, "*.py") + # round 2 + src_dir_path = os.path.join(os.getcwd(), "test", "ably", "rest") + dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync", "rest") + src_files = find_files(src_dir_path, "*.py") -unasync_files(src_files, [Rule(fromdir=src_dir_path, todir=dest_dir_path, output_file_prefix="sync_")]) + unasync_files(src_files, [Rule(fromdir=src_dir_path, todir=dest_dir_path, output_file_prefix="sync_")]) diff --git a/pyproject.toml b/pyproject.toml index 1e0a1e78..d45199f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,6 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] timeout = 30 + +[tool.poetry.scripts] +unasync = 'ably.scripts.unasync:run' From 7a84a89a79b1a66d69b9246e68489088bd445e80 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Mon, 9 Oct 2023 19:40:09 +0530 Subject: [PATCH 47/52] Updated updating.md markdown file --- UPDATING.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index cddda023..271ff04b 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -101,10 +101,6 @@ def main(): if __name__ == "__main__": main() ``` -- To use old `AblyRest` class, but with `sync` style API. Import it as, -```python -from ably.sync import AblyRestSync as AblyRest -``` #### Publishing Messages From 5cfb920aa405043c179d96bc4d9b29b801f2d985 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Mon, 9 Oct 2023 19:43:44 +0530 Subject: [PATCH 48/52] Fixed indentation issues for unasync file --- ably/scripts/unasync.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/ably/scripts/unasync.py b/ably/scripts/unasync.py index c4c8e57f..93a6c901 100644 --- a/ably/scripts/unasync.py +++ b/ably/scripts/unasync.py @@ -231,20 +231,20 @@ def run(): _IMPORTS_REPLACE["ably"] = "ably.sync" rename_classes = [ - "AblyRest", - "Push", - "PushAdmin", - "Channel", - "Channels", - "Auth", - "Http", - "PaginatedResult", - "HttpPaginatedResponse" + "AblyRest", + "Push", + "PushAdmin", + "Channel", + "Channels", + "Auth", + "Http", + "PaginatedResult", + "HttpPaginatedResponse" ] # here... for class_name in rename_classes: - _CLASS_RENAME[class_name] = f"{class_name}Sync" + _CLASS_RENAME[class_name] = f"{class_name}Sync" _STRING_REPLACE["Auth"] = "AuthSync" @@ -277,7 +277,6 @@ def run(): _STRING_REPLACE['ably.rest.rest.AblyRest.time'] = 'ably.sync.rest.rest.AblyRestSync.time' _STRING_REPLACE['ably.rest.auth.Auth._timestamp'] = 'ably.sync.rest.auth.AuthSync._timestamp' - # round 1 src_dir_path = os.path.join(os.getcwd(), "test", "ably") dest_dir_path = os.path.join(os.getcwd(), "test", "ably", "sync") From 984ad07e2fcd7a360580bdc282c3c258f74548fc Mon Sep 17 00:00:00 2001 From: Owen Pearson Date: Mon, 9 Oct 2023 15:55:46 +0100 Subject: [PATCH 49/52] refactor(unasync): move static class names to top of file --- ably/scripts/unasync.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ably/scripts/unasync.py b/ably/scripts/unasync.py index 93a6c901..ed148742 100644 --- a/ably/scripts/unasync.py +++ b/ably/scripts/unasync.py @@ -4,6 +4,18 @@ import tokenize_rt +rename_classes = [ + "AblyRest", + "Push", + "PushAdmin", + "Channel", + "Channels", + "Auth", + "Http", + "PaginatedResult", + "HttpPaginatedResponse" +] + _TOKEN_REPLACE = { "__aenter__": "__enter__", "__aexit__": "__exit__", @@ -230,18 +242,6 @@ def run(): _IMPORTS_REPLACE["ably"] = "ably.sync" - rename_classes = [ - "AblyRest", - "Push", - "PushAdmin", - "Channel", - "Channels", - "Auth", - "Http", - "PaginatedResult", - "HttpPaginatedResponse" - ] - # here... for class_name in rename_classes: _CLASS_RENAME[class_name] = f"{class_name}Sync" From 40750793d5c7685536f05916f10b84bf60ac667b Mon Sep 17 00:00:00 2001 From: Owen Pearson Date: Mon, 9 Oct 2023 15:59:06 +0100 Subject: [PATCH 50/52] docs: update migration guide sync api notice --- UPDATING.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 271ff04b..fff56553 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -86,9 +86,8 @@ To see which versions of Python we test the SDK against, please look at our The 1.2.0 version introduces a breaking change, which changes the way of interacting with the SDK from synchronous to asynchronous, using [the `asyncio` foundational library](https://docs.python.org/3.7/library/asyncio.html) to provide support for `async`/`await` syntax. Because of this breaking change, every call that interacts with the Ably REST API must be refactored to this asynchronous way. -Important Update: -- If you want to keep using old synchronous style API, import `AblyRestSync` client instead. -- This is applicable only for Ably REST APIs. +For backwards compatibility, in ably-python 2.0.2 we have added a backwards compatible REST client so that you can still use the synchronous version of the REST interface if you are migrating forwards from version 1.1. +In order to use the synchronous variant, you can import the `AblyRestSync` constructor from `ably.sync`: ```python from ably.sync import AblyRestSync @@ -270,4 +269,4 @@ Must now be replaced with this new style, asynchronous form: ```python await client.time() await client.close() -``` \ No newline at end of file +``` From f17d3a6df4f33dd1c554179b1fc5224b5ea98032 Mon Sep 17 00:00:00 2001 From: Owen Pearson Date: Mon, 9 Oct 2023 16:12:51 +0100 Subject: [PATCH 51/52] docs: add sync api notice to README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index cd12649e..392b640a 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,9 @@ introduced by version 1.2.0. ### Using the Rest API +> [!NOTE] +> Please note that since version 2.0.2 we also provide a synchronous variant of the REST interface which is can be accessed as `from ably.sync import AblyRestSync`. + All examples assume a client and/or channel has been created in one of the following ways: With closing the client manually: From b6b463bf29392b0abe6566311bd212a1e56853b8 Mon Sep 17 00:00:00 2001 From: Owen Pearson Date: Mon, 9 Oct 2023 16:32:30 +0100 Subject: [PATCH 52/52] build: include generated files in published package --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d45199f7..042de6e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Topic :: Software Development :: Libraries :: Python Modules", ] +include = [ + 'ably/**/*.py' +] [tool.poetry.dependencies] python = "^3.7"