Skip to content

Commit

Permalink
inference test improved
Browse files Browse the repository at this point in the history
  • Loading branch information
maximtrp committed Apr 6, 2021
1 parent b70e254 commit 0e9fa36
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_btm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ def test_btm_class(self):
biterms = btm.get_biterms(docs_vec)

LOGGER.info('Modeling started')
topics_num = 8
model = btm.BTM(
X, vocabulary, seed=12321, T=8, W=vocabulary.size, M=20, alpha=50/8, beta=0.01)
X, vocabulary, seed=12321, T=topics_num, W=vocabulary.size,
M=20, alpha=50/topics_num, beta=0.01)
# t1 = time.time()
model.fit(biterms, iterations=20)
# LOGGER.info(model.theta_)
# t2 = time.time()
# LOGGER.info(t2 - t1)
self.assertIsInstance(model.matrix_topics_words_, np.ndarray)
self.assertTupleEqual(
model.matrix_topics_words_.shape, (8, vocabulary.size))
model.matrix_topics_words_.shape, (topics_num, vocabulary.size))
LOGGER.info('Modeling finished')
# top_words = btm.get_top_topic_words(model)
# LOGGER.info(top_words)
Expand All @@ -46,7 +48,8 @@ def test_btm_class(self):
LOGGER.info('Model saving finished')

LOGGER.info('Inference started')
p_zd = model.transform(docs_vec)
p_zd = model.transform(docs_vec[:1000])
self.assertTupleEqual(p_zd.shape, (1000, topics_num))
# LOGGER.info(p_zd)
LOGGER.info('Inference "sum_b" finished')
p_zd = model.transform(docs_vec, infer_type='sum_w')
Expand Down

0 comments on commit 0e9fa36

Please sign in to comment.