From 70ae6d7b90f0db905ca634fc1bae43b80fc8d672 Mon Sep 17 00:00:00 2001 From: Maksim Terpilowski Date: Thu, 19 Aug 2021 17:24:52 +0400 Subject: [PATCH] renyi entropy added --- docs/source/bitermplus.metrics.rst | 3 +- setup.py | 6 +-- src/bitermplus/_metrics.pyx | 83 +++++++++++++++++++++++++++++- tests/test_btm.py | 6 +++ 4 files changed, 93 insertions(+), 5 deletions(-) diff --git a/docs/source/bitermplus.metrics.rst b/docs/source/bitermplus.metrics.rst index 77a1731..3a74ab2 100644 --- a/docs/source/bitermplus.metrics.rst +++ b/docs/source/bitermplus.metrics.rst @@ -4,4 +4,5 @@ Metrics .. currentmodule:: bitermplus .. autofunction:: coherence -.. autofunction:: perplexity \ No newline at end of file +.. autofunction:: perplexity +.. autofunction:: entropy \ No newline at end of file diff --git a/setup.py b/setup.py index 5274ea3..da6b605 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, Extension -from Cython.Build import cythonize from platform import system +from Cython.Build import cythonize # from numpy import get_include extra_link_args = ['-lomp'] if system() == 'Darwin' else ['-fopenmp'] @@ -12,12 +12,12 @@ Extension( "bitermplus._btm", sources=["src/bitermplus/_btm.pyx"], - # include_dirs=[get_include()], - # library_dirs=[get_include()], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args), Extension( "bitermplus._metrics", + # include_dirs=[get_include()], + # library_dirs=[get_include()], sources=["src/bitermplus/_metrics.pyx"], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args), diff --git a/src/bitermplus/_metrics.pyx b/src/bitermplus/_metrics.pyx index 639953f..a4c91e1 100644 --- a/src/bitermplus/_metrics.pyx +++ b/src/bitermplus/_metrics.pyx @@ -1,5 +1,6 @@ -__all__ = ['perplexity', 'coherence'] +__all__ = ['perplexity', 'coherence', 'entropy'] +from cython.view cimport array from libc.math cimport exp, log from typing import Union from pandas import DataFrame @@ -160,3 +161,83 @@ cpdef coherence( coherence[t] = logSum return np.array(coherence) + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.cdivision(True) +cpdef entropy( + double[:, :] p_wz): + """Renyi entropy calculation routine [1]_. + + Renyi entropy can be used to estimate the optimal number of topics. + + Parameters + ---------- + p_wz : np.ndarray + Topics vs words probabilities matrix (T x W). + + Returns + ------- + renyi : double + Renyi entropy value. + + References + ---------- + .. [1] Koltcov, S. (2018). Application of Rényi and Tsallis entropies to + topic modeling optimization. Physica A: Statistical Mechanics and its + Applications, 512, 1192-1204. + """ + # Words number + cdef int W = p_wz.shape[1] + # Topics number + cdef int T = p_wz.shape[0] + + # Initializing variables + cdef double word_ratio = 0. + cdef double sum_prob = 0. + cdef double shannon = 0. + cdef double energy = 0. + cdef double int_energy = 0. + cdef double free_energy = 0. + cdef double renyi = 0. + cdef double thresh = 1. + cdef int t = 0 + cdef int w = 0 + + # Setting threshold + thresh /= W + + # Maximum probability of each word + cdef double[:] p_max = array( + shape=(W, ), itemsize=sizeof(double), format="d", + allocate_buffer=True) + p_max[...] = 0. + + for w in range(W): + for t in range(T): + if p_wz[t, w] > p_max[w]: + p_max[w] = p_wz[t, w] + + # Select the probabilities larger than thresh + for w in range(W): + if p_max[w] > thresh: + sum_prob += p_max[w] + word_ratio += 1 + + # Shannon entropy + shannon = log(word_ratio / (W * T)) + + # Internal energy + int_energy = -log(sum_prob / T) + + # Free energy + free_energy = int_energy - shannon * T + + # Renyi entropy + if T == 1: + renyi = free_energy / T + else: + renyi = free_energy / (T-1) + + return renyi diff --git a/tests/test_btm.py b/tests/test_btm.py index 5ea4116..cd5df80 100644 --- a/tests/test_btm.py +++ b/tests/test_btm.py @@ -77,6 +77,12 @@ def test_btm_class(self): self.assertGreater(coherence.shape[0], 0) LOGGER.info('Coherence testing finished') + LOGGER.info('Entropy testing started') + entropy = btm.entropy(model.matrix_topics_words_) + self.assertNotEqual(entropy, 0) + LOGGER.info("Entropy value: {}".format(entropy)) + LOGGER.info('Entropy testing finished') + LOGGER.info('Model loading started') with open('model.pickle', 'rb') as file: self.assertIsInstance(pkl.load(file), btm._btm.BTM)