Skip to content

Commit

Permalink
fix multiple object failure, plus a test to prove it.
Browse files Browse the repository at this point in the history
  • Loading branch information
defrex committed Jul 9, 2013
1 parent ec8b3e6 commit d73c270
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
8 changes: 5 additions & 3 deletions src/django_fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def to_python(self, value):
else:
decrypt_value = binascii.a2b_hex(value[len(self.prefix):])
return force_unicode(
self.cipher.decrypt(
decrypt_value
).split('\0')[0]
self.cipher.decrypt(decrypt_value).split('\0')[0]
)
return value

Expand All @@ -102,6 +100,10 @@ def get_db_prep_value(self, value, connection=None, prepared=False):
value += "\0" + ''.join([random.choice(string.printable)
for index in range(padding-1)])
if self.block_type:
self.cipher = self.cipher_object.new(
settings.SECRET_KEY[:32],
getattr(self.cipher_object, self.block_type),
self.iv)
value = self.prefix + binascii.b2a_hex(self.iv + self.cipher.encrypt(value))
else:
value = self.prefix + binascii.b2a_hex(self.cipher.encrypt(value))
Expand Down
33 changes: 26 additions & 7 deletions src/django_fields/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PickleObject(models.Model):
data = PickleField()


class EmailObject(models.Model):
class EmailObject(models.Model):
max_email = 255
email = EncryptedEmailField()

Expand Down Expand Up @@ -87,7 +87,7 @@ def test_encryption(self):
encrypted_password = self._get_encrypted_password(obj.id)
self.assertNotEqual(encrypted_password, password)
self.assertTrue(encrypted_password.startswith('$AES$'))

def test_encryption_w_cipher(self):
"""
Test that the database values are actually encrypted when using
Expand All @@ -104,6 +104,25 @@ def test_encryption_w_cipher(self):
self.assertNotEqual(encrypted_password, password)
self.assertTrue(encrypted_password.startswith('$AES$MODE_CBC$'))

def test_multiple_encryption_w_cipher(self):
"""
Test that the database values are actually encrypted when using
non-default cipher types.
"""
password = 'this is a password!!' # 20 chars
obj = CipherEncObject(password = password)
obj.save()
# The value from the retrieved object should be the same...
obj = CipherEncObject.objects.get(id=obj.id)
self.assertEqual(password, obj.password)

password = 'another password!!' # 20 chars
obj = CipherEncObject(password = password)
obj.save()
# The value from the retrieved object should be the same...
obj = CipherEncObject.objects.get(id=obj.id)
self.assertEqual(password, obj.password)

def test_max_field_length(self):
password = 'a' * EncObject.max_password
obj = EncObject(password = password)
Expand All @@ -122,7 +141,7 @@ def test_UTF8(self):
obj.save()
obj = EncObject.objects.get(id=obj.id)
self.assertEqual(password, obj.password)

def test_consistent_encryption(self):
"""
The same password should not encrypt the same way twice.
Expand All @@ -134,7 +153,7 @@ def test_consistent_encryption(self):
for pwd_length in range(1,21): # 1-20 inclusive
enc_pwd_1, enc_pwd_2 = self._get_two_passwords(pwd_length)
self.assertNotEqual(enc_pwd_1, enc_pwd_2)

def test_minimum_padding(self):
"""
There should always be at least two chars of padding.
Expand Down Expand Up @@ -163,7 +182,7 @@ def _get_encrypted_password(self, id):
passwords = map(lambda x: x[0], cursor.fetchall())
self.assertEqual(len(passwords), 1) # only one
return passwords[0]

def _get_encrypted_password_cipher(self, id):
cursor = connection.cursor()
cursor.execute("select password from django_fields_cipherencobject where id = %s", [id,])
Expand Down Expand Up @@ -352,7 +371,7 @@ def test_UTF8(self):
obj.save()
obj = EmailObject.objects.get(id=obj.id)
self.assertEqual(email, obj.email)

def test_consistent_encryption(self):
"""
The same password should not encrypt the same way twice.
Expand All @@ -364,7 +383,7 @@ def test_consistent_encryption(self):
for email_length in range(1,21): # 1-20 inclusive
enc_email_1, enc_email_2 = self._get_two_emails(email_length)
self.assertNotEqual(enc_email_1, enc_email_2)

def test_minimum_padding(self):
"""
There should always be at least two chars of padding.
Expand Down

0 comments on commit d73c270

Please sign in to comment.