Skip to content

Commit

Permalink
Merge pull request #13 from ThoughtRiver/feature-add-lru-cached-reader
Browse files Browse the repository at this point in the history
Feature: Lru-Cached reader
  • Loading branch information
DomHudson authored Dec 16, 2019
2 parents 17ff54a + 89d65d0 commit 152fd37
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 212 deletions.
7 changes: 2 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@ cache: pip
python:
- 3.6
install:
- pip install --upgrade pip
- pip install . && pip install flake8
- pip install .[develop]
before_script:
# stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- flake8 . --count
script:
- pytest
notifications:
Expand Down
32 changes: 24 additions & 8 deletions lmdb_embeddings/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,28 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""


import functools
import lmdb
from lmdb_embeddings import exceptions
from lmdb_embeddings.serializers import PickleSerializer


class LmdbEmbeddingsReader:

MAX_READERS = 2048

def __init__(self, path, unserializer = PickleSerializer.unserialize, **kwargs):
""" Constructor.
:return void
:param str path:
:param callable unserializer:
:return void:
"""
self.unserializer = unserializer
self.environment = lmdb.open(
path,
readonly = True,
max_readers = 2048,
max_readers = self.MAX_READERS,
max_spare_txns = 2,
lock = kwargs.pop('lock', False),
**kwargs
Expand All @@ -44,15 +48,27 @@ def __init__(self, path, unserializer = PickleSerializer.unserialize, **kwargs):
def get_word_vector(self, word):
""" Fetch a word from the LMDB database.
:raises lmdb_embeddings.exceptions.MissingWordError
:return np.array
:param str word:
:raises lmdb_embeddings.exceptions.MissingWordError:
:return np.array:
"""
with self.environment.begin() as transaction:
word_vector = transaction.get(word.encode(encoding = 'UTF-8'))

if word_vector is None:
raise exceptions.MissingWordError(
'"%s" does not exist in the database.' % word
)
raise exceptions.MissingWordError('"%s" does not exist in the database.' % word)

return self.unserializer(word_vector)


class LruCachedLmdbEmbeddingsReader(LmdbEmbeddingsReader):

@functools.lru_cache(maxsize = 50000)
def get_word_vector(self, word):
""" Fetch a word from the LMDB database.
:param str word:
:raises lmdb_embeddings.exceptions.MissingWordError:
:return np.array:
"""
return super().get_word_vector(word)
26 changes: 11 additions & 15 deletions lmdb_embeddings/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ class PickleSerializer:
def serialize(vector):
""" Serializer a vector using pickle.
:return bytes
:param np.array vector:
:return bytes:
"""
return pickletools.optimize(
pickle.dumps(vector, pickle.HIGHEST_PROTOCOL)
)
return pickletools.optimize(pickle.dumps(vector, pickle.HIGHEST_PROTOCOL))

@staticmethod
def unserialize(serialized_vector):
""" Unserialize a vector using pickle.
:return np.array
:param bytes serialized_vector:
:return np.array:
"""
return pickle.loads(serialized_vector)

Expand All @@ -52,20 +52,16 @@ class MsgpackSerializer:
def serialize(vector):
""" Serializer a vector using msgpack.
:return bytes
:param np.array vector:
:return bytes:
"""
return msgpack.packb(
vector,
default = msgpack_numpy.encode
)
return msgpack.packb(vector, default = msgpack_numpy.encode)

@staticmethod
def unserialize(serialized_vector):
""" Unserialize a vector using msgpack.
:return np.array
:param bytes serialized_vector:
:return np.array:
"""
return msgpack.unpackb(
serialized_vector,
object_hook = msgpack_numpy.decode
)
return msgpack.unpackb(serialized_vector, object_hook = msgpack_numpy.decode)
39 changes: 0 additions & 39 deletions lmdb_embeddings/tests/base.py

This file was deleted.

128 changes: 128 additions & 0 deletions lmdb_embeddings/tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
LMDB Embeddings - Fast word vectors with little memory usage in Python.
[email protected]
Copyright (C) 2018 ThoughtRiver Limited
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""


import os

import numpy as np
import pytest

from lmdb_embeddings import exceptions
from lmdb_embeddings.reader import LmdbEmbeddingsReader
from lmdb_embeddings.reader import LruCachedLmdbEmbeddingsReader
from lmdb_embeddings.serializers import MsgpackSerializer
from lmdb_embeddings.writer import LmdbEmbeddingsWriter


class TestEmbeddings:

def test_write_embeddings(self, tmp_path):
""" Ensure we can write embeddings to disk without error.
:param pathlib.PosixPath tmp_path:
:return void:
"""
directory_path = str(tmp_path)

LmdbEmbeddingsWriter([
('the', np.random.rand(10)),
('is', np.random.rand(10))
]).write(directory_path)

assert os.listdir(directory_path)

def test_write_embeddings_generator(self, tmp_path):
""" Ensure we can a generator of embeddings to disk without error.
:param pathlib.PosixPath tmp_path:
:return void:
"""
directory_path = str(tmp_path)
embeddings_generator = ((str(i), np.random.rand(10)) for i in range(10))

LmdbEmbeddingsWriter(embeddings_generator).write(directory_path)

assert os.listdir(directory_path)

@pytest.mark.parametrize('reader_class', (LruCachedLmdbEmbeddingsReader, LmdbEmbeddingsReader))
def test_reading_embeddings(self, tmp_path, reader_class):
""" Ensure we can retrieve embeddings from the database.
:param pathlib.PosixPath tmp_path:
:return void:
"""
directory_path = str(tmp_path)

the_vector = np.random.rand(10)
LmdbEmbeddingsWriter([
('the', the_vector),
('is', np.random.rand(10))
]).write(directory_path)

assert reader_class(directory_path).get_word_vector('the').tolist() == the_vector.tolist()

@pytest.mark.parametrize('reader_class', (LruCachedLmdbEmbeddingsReader, LmdbEmbeddingsReader))
def test_missing_word_error(self, tmp_path, reader_class):
""" Ensure a MissingWordError exception is raised if the word does not exist in the
database.
:param pathlib.PosixPath tmp_path:
:return void:
"""
directory_path = str(tmp_path)

LmdbEmbeddingsWriter([
('the', np.random.rand(10)),
('is', np.random.rand(10))
]).write(directory_path)

reader = reader_class(directory_path)

with pytest.raises(exceptions.MissingWordError):
reader.get_word_vector('unknown')

def test_word_too_long(self, tmp_path):
""" Ensure we do not get an exception if attempting to write aword longer than LMDB's
maximum key size.
:param pathlib.PosixPath tmp_path:
:return void:
"""
directory_path = str(tmp_path)

LmdbEmbeddingsWriter([('a' * 1000, np.random.rand(10))]).write(directory_path)

@pytest.mark.parametrize('reader_class', (LruCachedLmdbEmbeddingsReader, LmdbEmbeddingsReader))
def test_msgpack_serialization(self, tmp_path, reader_class):
""" Ensure we can save and retrieve embeddings serialized with msgpack.
:param pathlib.PosixPath tmp_path:
:return void:
"""
directory_path = str(tmp_path)
the_vector = np.random.rand(10)

LmdbEmbeddingsWriter(
[('the', the_vector), ('is', np.random.rand(10))],
serializer = MsgpackSerializer.serialize
).write(directory_path)

reader = reader_class(directory_path, unserializer = MsgpackSerializer.unserialize)
assert reader.get_word_vector('the').tolist() == the_vector.tolist()
Loading

0 comments on commit 152fd37

Please sign in to comment.