refactor get_relative/absolute_freq for consistency
Return results in the same structure/format.
Fix issues with differing col_names.
JPapir committed Sep 23, 2024
1 parent c79905d commit 1e6c39e
Showing 1 changed file with 56 additions and 23 deletions.
79 changes: 56 additions & 23 deletions src/qumin/representations/frequencies.py
@@ -63,6 +63,7 @@ def initialize(cls, package, default_source=False, **kwargs):
Arguments:
p (frictionless.Package): package to analyze
default_source (Dict[str, str]): name of the source to use when several are available.
**kwargs: keyword arguments for frequency reading methods.
"""

cls.p = package
@@ -180,7 +181,7 @@ def _read_other_frequencies(cls, name, real=True):
setattr(cls, name, table[['value', 'source']])

@classmethod
def get_absolute_freq(cls, mean=False, group_on=None, skipna=False, **kwargs):
def get_absolute_freq(cls, mean=False, group_on=False, skipna=False, **kwargs):
"""
Return the frequency of an item for a given source
@@ -191,9 +192,14 @@ def get_absolute_freq(cls, mean=False, group_on=None, skipna=False, **kwargs):
>>> p = fl.Package('tests/data/TestPackage/test.package.json')
>>> Frequencies.initialize(p, real=True)
>>> Frequencies.get_absolute_freq(filters={'lexeme':'q'}, skipna=True)
38.0
>>> Frequencies.get_absolute_freq(filters={'lexeme':'q'}, skipna=False)
>>> Frequencies.get_absolute_freq(filters={'lexeme':'q'}, group_on="index", skipna=True)
form_id
11 12.0
12 6.0
14 20.0
18 NaN
Name: value, dtype: float64
>>> Frequencies.get_absolute_freq(filters={'lexeme':'q'})
nan
>>> Frequencies.get_absolute_freq(filters={'cell':'third'}, mean=True, skipna=True)
20.0
@@ -206,26 +212,40 @@ def get_absolute_freq(cls, mean=False, group_on=None, skipna=False, **kwargs):
Name: value, dtype: float64
Arguments:
group_on (List[str]): columns on which absolute frequencies should be computed.
If `"index"` is provided, simply returns the frequency stored.
If `False`, aggregates across all records.
mean (bool): Defaults to False. If True, returns a mean instead of a sum.
skipna (bool): Defaults to False. Skip NaN values for sums or means.
Returns:
`pandas.Series`: a Series which contains the output values, or a scalar when `group_on` is `False`.
The index is either the original one, or the grouping columns.
"""

# Filter using keys from mapping dict
sublist = cls._filter_frequencies(**kwargs)

if group_on is None:
if mean:
return sublist.value.mean(skipna=skipna)
else:
return sublist.value.sum(skipna=skipna)
if group_on == "index":
return sublist.value
elif group_on is False:
groups = [True] * len(sublist)
else:
if mean:
return sublist.groupby(by=group_on, group_keys=False).value.apply(lambda x: x.mean(skipna=skipna))
else:
return sublist.groupby(by=group_on, group_keys=False).value.apply(lambda x: x.sum(skipna=skipna))
groups = group_on

if mean:
def func(x): return x.mean(skipna=skipna)
else:
def func(x): return x.sum(skipna=skipna)

result = sublist.groupby(by=groups, group_keys=False).value.apply(func)

if group_on is False:
return result.iloc[0]
else:
return result
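For reference, a minimal standalone sketch of the new dispatch logic on hypothetical toy data (not the package's real schema): `"index"` returns the stored per-row values, `False` collapses everything into one dummy group whose single aggregate is unwrapped with `.iloc[0]`, and anything else is passed to `groupby` as-is.

    import numpy as np
    import pandas as pd

    # Hypothetical toy table standing in for the filtered sublist.
    sublist = pd.DataFrame(
        {"value": [12.0, 6.0, 20.0, np.nan]},
        index=pd.Index([11, 12, 14, 18], name="form_id"),
    )

    def absolute_freq(sublist, mean=False, group_on=False, skipna=False):
        if group_on == "index":
            return sublist.value          # stored per-row frequencies
        groups = [True] * len(sublist) if group_on is False else group_on
        agg = (lambda x: x.mean(skipna=skipna)) if mean else (lambda x: x.sum(skipna=skipna))
        result = sublist.groupby(by=groups, group_keys=False).value.apply(agg)
        # A single dummy group yields a one-row Series; unwrap it to a scalar.
        return result.iloc[0] if group_on is False else result

    print(absolute_freq(sublist, skipna=True))   # 38.0
    print(absolute_freq(sublist))                # nan (skipna=False propagates NaN)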

@classmethod
def get_relative_freq(cls, group_on, **kwargs):
def get_relative_freq(cls, group_on=False, **kwargs):
"""
Returns the relative frequencies of a set of rows according to a set of grouping columns.
If any value in a group is missing, a uniform distribution is generated for that group.
@@ -253,35 +273,48 @@ def get_relative_freq(cls, group_on, **kwargs):
Arguments:
group_on (List[str]): columns on which relative frequencies should be computed.
If `False`, a single distribution is computed across all records.
Returns:
`pandas.DataFrame`: a DataFrame which contains a `result` column with the output value.
The index is the original one. The grouping columns are also provided.
"""

# Filter using keys from mapping dict
sublist = cls._filter_frequencies(**kwargs)

if group_on is False:
groups = [True] * len(sublist)
col_names = list()
else:
groups = group_on
col_names = list(group_on)

# 1. We first get the nb of items in each group
sublist['result'] = sublist\
.groupby(group_on, sort=False).value\
.groupby(groups, sort=False).value\
.transform("size")

sublist['result'] = sublist.result.astype('float64')
sublist.result = sublist.result.astype('float64')

# 2. If there are any NaN values, we give a uniform frequency to the group
sublist['notna'] = True
sublist.loc[sublist.value.isna(), 'notna'] = False
nanval = (sublist.result != 1) & ~sublist.groupby(group_on).notna.transform('all')
nanval = (sublist.result != 1) & ~sublist.groupby(groups).notna.transform('all')
sublist.loc[nanval, 'result'] = 1/sublist.loc[nanval, 'result']

# 3. If all values are filled and the group is bigger than one, we normalize each value by the group sum
selector = sublist['result'] > 1
selector = sublist.result > 1

if group_on is False:
groups = selector

sublist.loc[selector, 'result'] = sublist.loc[selector, 'value']/sublist.loc[selector]\
.groupby(group_on, sort=False).value.transform('sum')
.groupby(groups, sort=False).value.transform('sum')

sublist.reset_index(inplace=True)
sublist.set_index(cls.col_names, inplace=True)
return sublist
return sublist[col_names + ["result"]]
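A hedged, self-contained sketch of the three steps above, using assumed toy columns (`cell`, `value`) rather than the package's real schema: group sizes via `transform('size')`, a uniform 1/n wherever a group contains a missing value, and sum-normalization for fully observed groups.

    import numpy as np
    import pandas as pd

    # Hypothetical toy data: group 'a' has a missing value, group 'b' is complete.
    df = pd.DataFrame({
        "cell": ["a", "a", "b", "b"],
        "value": [3.0, np.nan, 2.0, 6.0],
    })
    groups = ["cell"]

    # 1. Number of items in each group.
    df["result"] = df.groupby(groups, sort=False).value.transform("size").astype("float64")

    # 2. Groups containing NaN get a uniform 1/n distribution.
    nanval = (df.result != 1) & ~df.value.notna().groupby(df.cell).transform("all")
    df.loc[nanval, "result"] = 1 / df.loc[nanval, "result"]

    # 3. Fully observed groups larger than one are normalized by their group sum.
    selector = df.result > 1
    df.loc[selector, "result"] = (
        df.loc[selector, "value"]
        / df.loc[selector].groupby(groups, sort=False).value.transform("sum")
    )

    print(df.result.tolist())   # [0.5, 0.5, 0.25, 0.75]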

@classmethod
def _filter_frequencies(cls, source=None, filters={}, data="forms", inplace=False):
def _filter_frequencies(cls, data="forms", source=None, filters={}, inplace=False):
"""Filters the dataframe based on a set of filters
provided as a dictionary.
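The body is collapsed in this view; as a rough illustration only, dict-based filtering can be sketched as below (a hypothetical helper, ignoring the real method's `source`, `data`, and `inplace` handling).

    import pandas as pd

    def filter_frequencies(df: pd.DataFrame, filters: dict = {}) -> pd.DataFrame:
        # Keep only rows whose column values match every key/value pair in `filters`.
        for col, val in filters.items():
            df = df[df[col] == val]
        return df.copy()

    # e.g. filter_frequencies(forms, {"lexeme": "q"})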