Skip to content

Commit

Permalink
Add tests for RestrictedSqliteDict
Browse files Browse the repository at this point in the history
  • Loading branch information
mstopa-splunk committed Oct 29, 2024
1 parent 1035394 commit 27b7786
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
22 changes: 5 additions & 17 deletions package/etc/pylib/sqlite_utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,25 @@
import builtins
import io
import pickle
from base64 import b64decode
from sqlitedict import SqliteDict

safe_builtins = {
'range',
'complex',
'set',
'frozenset',
'slice',
}


class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow safe classes from builtins.
if module == "builtins" and name in safe_builtins:
return getattr(builtins, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
"""Override pickle.Unpickler.find_class() to prevent deserialization of class instances."""
raise pickle.UnpicklingError("Class deserialization is disabled")


def restricted_loads(s):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(s)).load()

def restricted_decode(obj):
"""Overwrite sqlitedict.decode to prevent code injection."""
"""Overwrite sqlitedict.decode() to prevent code injection."""
return restricted_loads(bytes(obj))

def restricted_decode_key(key):
"""Overwrite sqlitedict.decode_key to prevent code injection."""
"""Overwrite sqlitedict.decode_key() to prevent code injection."""
return restricted_loads(b64decode(key.encode("ascii")))


Expand Down
38 changes: 37 additions & 1 deletion tests/test_name_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
# https://opensource.org/licenses/BSD-2-Clause

import datetime
import pickle
import random
import re
import tempfile
import time

from jinja2 import Environment
Expand All @@ -16,6 +18,7 @@
from .sendmessage import sendsingle
from .splunkutils import splunk_single
from package.etc.pylib.parser_source_cache import ip2int, int2ip
from package.etc.pylib.sqlite_utils import RestrictedSqliteDict

env = Environment()

Expand Down Expand Up @@ -73,4 +76,37 @@ def test_ipv4_utils():
@pytest.mark.name_cache
def test_ipv6_utils():
ip = generate_random_ipv6()
assert ip == int2ip(ip2int(ip))
assert ip == int2ip(ip2int(ip))

@pytest.mark.name_cache
def test_RestrictedSqliteDict_stores_and_retrieves_string():
with tempfile.NamedTemporaryFile(delete=True) as temp_db_file:
cache = RestrictedSqliteDict(f"{temp_db_file.name}.db")
cache["key"] = "value"
cache.commit()
cache.close()

cache = RestrictedSqliteDict(f"{temp_db_file.name}.db")
assert cache["key"] == "value"
cache.close()

@pytest.mark.name_cache
def test_RestrictedSqliteDict_prevents_code_injection():
class InjectionTestClass:
def __reduce__(self):
import os
return os.system, ('touch pwned.txt',)

with tempfile.NamedTemporaryFile(delete=True) as temp_db_file:
# Initialize the RestrictedSqliteDict and insert an 'injected' object
cache = RestrictedSqliteDict(f"{temp_db_file.name}.db")
cache["key"] = InjectionTestClass()
cache.commit()
cache.close()

# Re-open cache and attempt to deserialize 'injected' object
# Expecting UnpicklingError due to RestrictedSqliteDict restrictions
cache = RestrictedSqliteDict(f"{temp_db_file.name}.db")
with pytest.raises(pickle.UnpicklingError):
_ = cache["key"]
cache.close()

0 comments on commit 27b7786

Please sign in to comment.