Skip to content

Commit

Permalink
Merge pull request #5 from rishiraj/development
Browse files Browse the repository at this point in the history
add save load feature
  • Loading branch information
rishiraj authored Apr 14, 2024
2 parents ed3c547 + 5c20310 commit e7d178c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
jax==0.4.26
sentence-transformers==2.6.1
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

with open('requirements.txt') as fh:
requirements = fh.read().splitlines()

setuptools.setup(
name="spanking",
version=spanking_version,
Expand All @@ -45,7 +48,5 @@
],
python_requires=">=3.9",
entry_points={"console_scripts": ["spanking = spanking.main:main"]},
install_requires=[
"numpy >= 1.26.4",
],
install_requires=requirements,
)
10 changes: 10 additions & 0 deletions spanking/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
import pickle
from sentence_transformers import SentenceTransformer

class VectorDB:
Expand Down Expand Up @@ -37,6 +38,15 @@ def search(self, query, top_k=5):
top_indices = jnp.argsort(similarities)[-top_k:][::-1]
return [(self.texts[i], similarities[i]) for i in top_indices]

def save(self, file_path):
with open(file_path, 'wb') as file:
pickle.dump(self, file)

@staticmethod
def load(file_path):
with open(file_path, 'rb') as file:
return pickle.load(file)

def __len__(self):
return len(self.texts)

Expand Down

0 comments on commit e7d178c

Please sign in to comment.