From 5c20310b4b78d7a0594705acb16187776e6fd54f Mon Sep 17 00:00:00 2001 From: rishiraj Date: Mon, 15 Apr 2024 03:30:53 +0530 Subject: [PATCH] add save load feature --- requirements.txt | 2 ++ setup.py | 7 ++++--- spanking/main.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3edcb95 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +jax==0.4.26 +sentence-transformers==2.6.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 1443d3e..c4dfaec 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,9 @@ with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() +with open('requirements.txt') as fh: + requirements = fh.read().splitlines() + setuptools.setup( name="spanking", version=spanking_version, @@ -45,7 +48,5 @@ ], python_requires=">=3.9", entry_points={"console_scripts": ["spanking = spanking.main:main"]}, - install_requires=[ - "numpy >= 1.26.4", - ], + install_requires=requirements, ) diff --git a/spanking/main.py b/spanking/main.py index 9c86db1..37fc3ac 100644 --- a/spanking/main.py +++ b/spanking/main.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +import pickle from sentence_transformers import SentenceTransformer class VectorDB: @@ -37,6 +38,15 @@ def search(self, query, top_k=5): top_indices = jnp.argsort(similarities)[-top_k:][::-1] return [(self.texts[i], similarities[i]) for i in top_indices] + def save(self, file_path): + with open(file_path, 'wb') as file: + pickle.dump(self, file) + + @staticmethod + def load(file_path): + with open(file_path, 'rb') as file: + return pickle.load(file) + def __len__(self): return len(self.texts)