forked from annypan/ilm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_sorting.py
69 lines (60 loc) · 2.98 KB
/
test_sorting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import unittest
import pickle
from sorting import load_pickle, filter_words, filter_ngrams, filter_sentences, filter_paragraphs, filter_documents, count_mask_types
from ilm.mask.hierarchical import MaskHierarchicalType
class TestSorting(unittest.TestCase):
    """Tests for the pickle loading, mask counting and mask filtering helpers.

    Every test works on the same fixture: the first ``NUM_DOCUMENTS``
    entries that ``load_pickle`` returns for ``SAMPLE_PICKLE``.  The tests
    index the loaded data as ``masks[doc][1][masking]`` and read the mask
    type from ``mask[0]``, so each document entry is expected to carry its
    maskings at index 1, and each mask its type at index 0.
    """

    # NOTE(review): relative path, so it is resolved against the current
    # working directory of the test run -- confirm the tests are always
    # launched from the directory that contains the fixture.
    SAMPLE_PICKLE = 'sample_pickle.pkl'
    NUM_DOCUMENTS = 20

    def setUp(self):
        """Load a fresh copy of the fixture before every test."""
        self.masks = load_pickle(self.SAMPLE_PICKLE, self.NUM_DOCUMENTS)

    def _run_filter(self, filter_fn):
        """Call ``filter_fn(self.masks, out)`` with an empty ``out`` list.

        The filter is expected to populate ``out``; the populated list is
        returned.
        """
        filtered = []
        filter_fn(self.masks, filtered)
        return filtered

    def _assert_mask_types_within(self, filtered, allowed_types):
        """Assert that every mask in ``filtered`` has a type in ``allowed_types``.

        ``assertIn`` reports the offending mask type on failure, which a
        bare ``assertTrue(a == b or ...)`` does not.
        """
        for document in filtered:
            for masking in document[1]:
                for mask in masking:
                    self.assertIn(mask[0], allowed_types)

    def test_pickle_loading(self):
        """``load_pickle`` returns exactly the requested number of entries."""
        self.assertEqual(len(self.masks), self.NUM_DOCUMENTS)

    def test_count_mask_types(self):
        """``count_mask_types`` yields the known counts for one sample masking."""
        self.assertEqual(
            count_mask_types(self.masks[0][1][14]),
            [0, 1, 1, 2, 2],
        )

    def test_filtering_words(self):
        """``filter_words`` output contains only WORD masks."""
        self._assert_mask_types_within(
            self._run_filter(filter_words),
            (MaskHierarchicalType.WORD,),
        )

    def test_filtering_ngrams(self):
        """``filter_ngrams`` output contains only WORD or NGRAM masks."""
        self._assert_mask_types_within(
            self._run_filter(filter_ngrams),
            (MaskHierarchicalType.WORD, MaskHierarchicalType.NGRAM),
        )

    def test_filtering_sentences(self):
        """``filter_sentences`` output contains only WORD, NGRAM or SENTENCE masks."""
        self._assert_mask_types_within(
            self._run_filter(filter_sentences),
            (
                MaskHierarchicalType.WORD,
                MaskHierarchicalType.NGRAM,
                MaskHierarchicalType.SENTENCE,
            ),
        )

    def test_filtering_paragraphs(self):
        """``filter_paragraphs`` output has no mask above PARAGRAPH level."""
        self._assert_mask_types_within(
            self._run_filter(filter_paragraphs),
            (
                MaskHierarchicalType.WORD,
                MaskHierarchicalType.NGRAM,
                MaskHierarchicalType.SENTENCE,
                MaskHierarchicalType.PARAGRAPH,
            ),
        )

    def test_number_of_maskings_remain_same(self):
        """The five filters together keep every masking of every document.

        For each document, the masking counts of the five filtered outputs
        must add up to the masking count of the unfiltered document.
        """
        masks = self.masks
        filters = (
            filter_words,
            filter_ngrams,
            filter_sentences,
            filter_paragraphs,
            filter_documents,
        )
        # All five filters run against the same loaded data, in this order.
        filtered_outputs = [self._run_filter(filter_fn) for filter_fn in filters]

        for didx in range(self.NUM_DOCUMENTS):
            # subTest identifies the failing document in the test report.
            with self.subTest(document=didx):
                total = sum(len(output[didx][1]) for output in filtered_outputs)
                self.assertEqual(total, len(masks[didx][1]))
if __name__ == "__main__":
    # Run every test in this module when the file is executed as a script.
    unittest.main()