diff --git a/src/backend/marsha/settings.py b/src/backend/marsha/settings.py index 97655bbd74..df24cf01df 100644 --- a/src/backend/marsha/settings.py +++ b/src/backend/marsha/settings.py @@ -171,7 +171,7 @@ class Base(Configuration): # } CHANNEL_LAYERS = { "default": { - "BACKEND": "channels_redis.core.RedisChannelLayer", + "BACKEND": "marsha.websocket.layers.JsonRedisChannelLayer", "CONFIG": { "hosts": values.ListValue( [("redis", 6379)], environ_name="REDIS_HOST", environ_prefix=None diff --git a/src/backend/marsha/websocket/layers.py b/src/backend/marsha/websocket/layers.py new file mode 100644 index 0000000000..855886ef66 --- /dev/null +++ b/src/backend/marsha/websocket/layers.py @@ -0,0 +1,36 @@ +"""Layers used by django channels""" +import json +import random + +from django.core.serializers.json import DjangoJSONEncoder + +from channels_redis.core import RedisChannelLayer + + +class JsonRedisChannelLayer(RedisChannelLayer): + """Use json to serialize and deserialize messages.""" + + def serialize(self, message): + """ + Serializes message in json. + """ + value = bytes(json.dumps(message, cls=DjangoJSONEncoder), encoding="utf-8") + if self.crypter: + value = self.crypter.encrypt(value) + + # As we use an sorted set to expire messages we need to guarantee uniqueness, + # with 12 bytes. + random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big") + return random_prefix + value + + def deserialize(self, message): + """ + Deserializes from a byte string. + """ + # Removes the random prefix + message = message[12:] + message = message.decode("utf-8") + + if self.crypter: + message = self.crypter.decrypt(message, self.expiry + 10) + return json.loads(message) diff --git a/src/backend/marsha/websocket/tests/test_layers.py b/src/backend/marsha/websocket/tests/test_layers.py new file mode 100644 index 0000000000..0373793081 --- /dev/null +++ b/src/backend/marsha/websocket/tests/test_layers.py @@ -0,0 +1,31 @@ +"""Test marsha layers use by django channels.""" +from django.test import TestCase + +from marsha.websocket.layers import JsonRedisChannelLayer + + +class JsonRedisChannelLayerTest(TestCase): + """Test serialize and deserialize.""" + + def test_serialize_message(self): + """ + Test default serialization method + """ + message = {"a": True, "b": None, "c": {"d": []}} + channel_layer = JsonRedisChannelLayer() + serialized = channel_layer.serialize(message) + self.assertIsInstance(serialized, bytes) + self.assertEqual(serialized[12:], b'{"a": true, "b": null, "c": {"d": []}}') + + def test_deserialize_message(self): + """ + Test default deserialization method + """ + message = ( + b'[\x85\xf8\xdeY\xe5\xa3}is\x0f3{"a": true, "b": null, "c": {"d": []}}' + ) + channel_layer = JsonRedisChannelLayer() + deserialized = channel_layer.deserialize(message) + + self.assertIsInstance(deserialized, dict) + self.assertEqual(deserialized, {"a": True, "b": None, "c": {"d": []}})