Skip to content

Commit

Permalink
Post review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
seba-aln committed Oct 2, 2023
1 parent 119857b commit 8cac1f3
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 168 deletions.
38 changes: 10 additions & 28 deletions pubnub/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,26 +112,13 @@ def decrypt(self, key, file):

class PubNubCryptoModule(PubNubCrypto):
FALLBACK_CRYPTOR_ID: str = '0000'
CRYPTOR_VERSION: str = 1
cryptor_map = {}
default_cryptor_id: str

def __init__(self, cryptor_map: Dict[str, PubNubCryptor], default_cryptor: PubNubCryptor):
self.cryptor_map = cryptor_map
self.default_cryptor_id = default_cryptor.CRYPTOR_ID

def register_cryptor(self, cryptor_id, cryptor_instance):
if len(cryptor_id) != 4:
raise PubNubException('Malformed cryptor_id.')

if cryptor_id in self.cryptor_map.keys():
raise PubNubException('Cryptor_id already in use')

if not isinstance(cryptor_instance, PubNubCrypto):
raise PubNubException('Invalid cryptor instance')

self.cryptor_map[cryptor_id] = cryptor_instance

def _validate_cryptor_id(self, cryptor_id: str) -> str:
cryptor_id = cryptor_id or self.default_cryptor_id

Expand All @@ -155,19 +142,13 @@ def encrypt(self, message: str, cryptor_id: str = None) -> str:
def decrypt(self, input):
data = b64decode(input)
header = self.decode_header(data)

if header:
cryptor_id = header['cryptor_id']
cryptor_version = header['cryptor_ver']
payload = CryptorPayload(data=data[header['length']:], cryptor_data=header['cryptor_data'])
if not header:
cryptor_id = self.FALLBACK_CRYPTOR_ID
cryptor_version = self.CRYPTOR_VERSION
payload = CryptorPayload(data=data)

if cryptor_id not in self.cryptor_map.keys() or cryptor_version > self.CRYPTOR_VERSION:
raise PubNubException('unknown cryptor error')

message = self._get_cryptor(cryptor_id).decrypt(payload)
try:
return json.loads(message)
Expand All @@ -189,27 +170,25 @@ def decrypt_file(self, file_data):
header = self.decode_header(file_data)
if header:
cryptor_id = header['cryptor_id']
cryptor_version = header['cryptor_ver']
payload = CryptorPayload(data=file_data[header['length']:], cryptor_data=header['cryptor_data'])
else:
cryptor_id = self.FALLBACK_CRYPTOR_ID
cryptor_version = self.CRYPTOR_VERSION
payload = CryptorPayload(data=file_data)

if cryptor_id not in self.cryptor_map.keys() or cryptor_version > self.CRYPTOR_VERSION:
if cryptor_id not in self.cryptor_map.keys():
raise PubNubException('unknown cryptor error')

return self._get_cryptor(cryptor_id).decrypt(payload, binary_mode=True)

def encode_header(self, cryptor_ver=None, cryptor_id: str = None, cryptor_data: any = None) -> str:
def encode_header(self, cryptor_id: str = None, cryptor_data: any = None) -> str:
if cryptor_id == self.FALLBACK_CRYPTOR_ID:
return b''
if cryptor_data and len(cryptor_data) > 65535:
raise PubNubException('Cryptor data is too long')
cryptor_id = self._validate_cryptor_id(cryptor_id)
cryptor_ver = cryptor_ver or self.CRYPTOR_VERSION

sentinel = b'PNED'
version = cryptor_ver.to_bytes(1, byteorder='big')
version = CryptoHeader.header_ver.to_bytes(1, byteorder='big')
crid = bytes(cryptor_id, 'utf-8')

if cryptor_data:
Expand All @@ -234,7 +213,10 @@ def decode_header(self, header: bytes) -> Union[None, CryptoHeader]:
return False

try:
cryptor_ver = header[4]
header_version = header[4]
if header_version > CryptoHeader.header_ver:
raise PubNubException('unknown cryptor error')

cryptor_id = header[5:9].decode()
crlen = header[9]
if crlen < 255:
Expand All @@ -245,9 +227,9 @@ def decode_header(self, header: bytes) -> Union[None, CryptoHeader]:
cryptor_data = header[12:12 + crlen]
hlen = 12 + crlen

return CryptoHeader(sentinel=sentinel, cryptor_ver=cryptor_ver, cryptor_id=cryptor_id,
return CryptoHeader(sentinel=sentinel, header_ver=header_version, cryptor_id=cryptor_id,
cryptor_data=cryptor_data, length=hlen)
except Exception:
except IndexError:
raise PubNubException('decryption error')


Expand Down
24 changes: 9 additions & 15 deletions pubnub/crypto_core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import hashlib
import json
import random
import os
import secrets

from abc import abstractmethod
from Cryptodome.Cipher import AES
from Cryptodome.Util.Padding import pad, unpad


class PubNubCrypto:
Expand All @@ -22,7 +23,7 @@ def decrypt(self, key, msg):

class CryptoHeader(dict):
sentinel: str
cryptor_ver: int
header_ver: int = 1
cryptor_id: str
cryptor_data: any
length: any
Expand All @@ -35,7 +36,6 @@ class CryptorPayload(dict):

class PubNubCryptor:
CRYPTOR_ID: str
CRYPTOR_VERSION: int = 1

@abstractmethod
def encrypt(self, data: bytes) -> CryptorPayload:
Expand Down Expand Up @@ -125,39 +125,33 @@ def get_secret(self, key):

class PubNubAesCbcCryptor(PubNubCryptor):
CRYPTOR_ID = 'ACRH'
CRYPTOR_VERSION: int = 1
mode = AES.MODE_CBC

def __init__(self, cipher_key):
self.cipher_key = cipher_key

def get_initialization_vector(self) -> bytes:
return os.urandom(16)
return secrets.token_bytes(16)

def get_secret(self, key) -> str:
return hashlib.sha256(key.encode("utf-8")).digest()

def pad(self, msg: bytes, block_size=AES.block_size) -> bytes:
padding = block_size - (len(msg) % block_size)
return msg + bytes(chr(padding) * padding, 'utf-8')

def depad(self, msg: bytes) -> bytes:
return msg[:-msg[-1]]

def encrypt(self, data: bytes, key=None) -> CryptorPayload:
key = key or self.cipher_key
secret = self.get_secret(key)
iv = self.get_initialization_vector()
cipher = AES.new(secret, mode=self.mode, iv=iv)
encrypted = cipher.encrypt(self.pad(data))
encrypted = cipher.encrypt(pad(data, AES.block_size))
return CryptorPayload(data=encrypted, cryptor_data=iv)

def decrypt(self, payload: CryptorPayload, key=None, binary_mode: bool = False):
key = key or self.cipher_key
secret = self.get_secret(key)
iv = payload['cryptor_data']

cipher = AES.new(secret, mode=self.mode, iv=iv)

if binary_mode:
return self.depad(cipher.decrypt(payload['data']))
return unpad(cipher.decrypt(payload['data']), AES.block_size)
else:
return self.depad(cipher.decrypt(payload['data'])).decode()
return unpad(cipher.decrypt(payload['data']), AES.block_size).decode()
121 changes: 0 additions & 121 deletions tests/acceptance/encryption/cryptor-module.feature

This file was deleted.

5 changes: 3 additions & 2 deletions tests/acceptance/encryption/steps/when_steps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from base64 import b64decode
from behave import when
from tests.acceptance.encryption.environment import PNContext, get_crypto_module, get_asset_path
from pubnub.exceptions import PubNubException


@when("I decrypt '{filename}' file")
Expand All @@ -11,7 +12,7 @@ def step_impl(context: PNContext, filename):
file_bytes = file_handle.read()
crypto.decrypt_file(file_bytes)
context.outcome = 'success'
except Exception as e:
except PubNubException as e:
context.outcome = str(e).replace('None: ', '')


Expand All @@ -36,5 +37,5 @@ def step_impl(context: PNContext, filename, file_mode):
file_bytes = file_handle.read()
context.decrypted_file = crypto.decrypt_file(file_bytes)
context.outcome = 'success'
except Exception as e:
except PubNubException as e:
context.outcome = str(e).replace('None: ', '')
21 changes: 19 additions & 2 deletions tests/unit/test_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_header_encoder(self):
assert b'PNED\x01ACRH\x00' == header

cryptor_data = b'\x21'
header = crypto.encode_header(cryptor_ver=1, cryptor_data=cryptor_data)
header = crypto.encode_header(cryptor_data=cryptor_data)
assert b'PNED\x01ACRH\x01' + cryptor_data == header

cryptor_data = b'\x21' * 255
Expand All @@ -108,7 +108,7 @@ def test_header_encoder(self):
def test_header_decoder(self):
crypto = AesCbcCryptoModule('myCipherKey', True)
header = crypto.decode_header(b'PNED\x01ACRH\x00')
assert header['cryptor_ver'] == 1
assert header['header_ver'] == 1
assert header['cryptor_id'] == 'ACRH'
assert header['cryptor_data'] == b''

Expand Down Expand Up @@ -188,3 +188,20 @@ def test_encrypt_module_decrypt_legacy_random_iv(self):
decrypted = crypto.decrypt(self.cipher_key, encrypted)

assert decrypted == original_message

def test_php_encrypted_crosscheck(self):
crypto = AesCbcCryptoModule(self.cipher_key, False)
phpmess = "KGc+SNJD7mIveY+KNIL/L9ZzAjC0dCJCju+HXRwSW2k="
decrypted = crypto.decrypt(phpmess)
assert decrypted == 'PHP can backwards Legacy static'

crypto = AesCbcCryptoModule(self.cipher_key, True)
phpmess = "PXjHv0L05kgj0mqIE9s7n4LDPrLtjnfamMoHyiMoL0R1uzSMsYp7dDfqEWrnoaqS"
decrypted = crypto.decrypt(phpmess)
assert decrypted == 'PHP can backwards Legacy random'

crypto = AesCbcCryptoModule(self.cipher_key, True)
phpmess = "UE5FRAFBQ1JIEHvl3cY3RYsHnbKm6VR51XG/Y7HodnkumKHxo+mrsxbIjZvFpVuILQ0oZysVwjNsDNMKiMfZteoJ8P1/" \
"mvPmbuQKLErBzS2l7vEohCwbmAJODPR2yNhJGB8989reTZ7Y7Q=="
decrypted = crypto.decrypt(phpmess)
assert decrypted == 'PHP can into space with headers and aes cbc and other shiny stuff'

0 comments on commit 8cac1f3

Please sign in to comment.