diff --git a/tests/test_llmlingua.py b/tests/test_llmlingua.py index 485ad50..fad44c8 100644 --- a/tests/test_llmlingua.py +++ b/tests/test_llmlingua.py @@ -55,7 +55,10 @@ class LLMLinguaTester(unittest.TestCase): def __init__(self, *args, **kwargs): super(LLMLinguaTester, self).__init__(*args, **kwargs) import nltk - nltk.download('punkt') + try: + nltk.data.find('tokenizers/punkt') + except LookupError: + nltk.download('punkt') self.llmlingua = PromptCompressor("lgaalves/gpt2-dolly", device_map="cpu") def test_general_compress_prompt(self): diff --git a/tests/test_longllmlingua.py b/tests/test_longllmlingua.py index 07f1323..edefb17 100644 --- a/tests/test_longllmlingua.py +++ b/tests/test_longllmlingua.py @@ -59,7 +59,10 @@ class LongLLMLinguaTester(unittest.TestCase): def __init__(self, *args, **kwargs): super(LongLLMLinguaTester, self).__init__(*args, **kwargs) import nltk - nltk.download('punkt') + try: + nltk.data.find('tokenizers/punkt') + except LookupError: + nltk.download('punkt') self.llmlingua = PromptCompressor("lgaalves/gpt2-dolly", device_map="cpu") def test_general_compress_prompt(self):