-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #298 from rsagroup/rdms-to-pandas
RDMs to Pandas DataFrame
- Loading branch information
Showing
5 changed files
with
112 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""Conversions from rsatoolbox classes to pandas table objects | ||
""" | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING | ||
from pandas import DataFrame | ||
import numpy | ||
from numpy import asarray | ||
if TYPE_CHECKING: | ||
from rsatoolbox.rdm.rdms import RDMs | ||
|
||
|
||
def rdms_to_df(rdms: RDMs) -> DataFrame: | ||
"""Create DataFrame representation of the RDMs object | ||
A column for: | ||
- dissimilarity | ||
- each rdm descriptor | ||
- two for each pattern descriptor, suffixed by _1 and _2 respectively | ||
Multiple RDMs are stacked row-wise. | ||
See also the `RDMs.to_df()` method which calls this function | ||
Args: | ||
rdms (RDMs): the object to convert | ||
Returns: | ||
DataFrame: long-form pandas DataFrame with | ||
dissimilarities and descriptors. | ||
""" | ||
n_rdms, n_pairs = rdms.dissimilarities.shape | ||
cols = dict(dissimilarity=rdms.dissimilarities.ravel()) | ||
for dname, dvals in rdms.rdm_descriptors.items(): | ||
# rename the default index desc as that has special meaning in df | ||
cname = 'rdm_index' if dname == 'index' else dname | ||
cols[cname] = numpy.repeat(dvals, n_pairs) | ||
for dname, dvals in rdms.pattern_descriptors.items(): | ||
ix = numpy.triu_indices(len(dvals), 1) | ||
# rename the default index desc as that has special meaning in df | ||
cname = 'pattern_index' if dname == 'index' else dname | ||
for p in (0, 1): | ||
cols[f'{cname}_{p+1}'] = numpy.tile(asarray(dvals)[ix[p]], n_rdms) | ||
return DataFrame(cols) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from __future__ import annotations | ||
from unittest import TestCase | ||
from typing import TYPE_CHECKING, Union, List | ||
from numpy.testing import assert_array_equal | ||
import numpy | ||
from pandas import Series, DataFrame | ||
if TYPE_CHECKING: | ||
from numpy.typing import NDArray | ||
|
||
|
||
class RdmsToPandasTests(TestCase): | ||
|
||
def assertValuesEqual(self, | ||
actual: Series, | ||
expected: Union[NDArray, List]): | ||
assert_array_equal(numpy.asarray(actual.values), expected) | ||
|
||
def test_to_df(self): | ||
"""Convert an RDMs object to a pandas DataFrame | ||
Default is long form; multiple rdms are stacked row-wise. | ||
""" | ||
from rsatoolbox.rdm.rdms import RDMs | ||
dissimilarities = numpy.random.rand(2, 6) | ||
rdms = RDMs( | ||
dissimilarities, | ||
rdm_descriptors=dict(xy=[c for c in 'xy']), | ||
pattern_descriptors=dict(abcd=numpy.asarray([c for c in 'abcd'])) | ||
) | ||
df = rdms.to_df() | ||
self.assertIsInstance(df, DataFrame) | ||
self.assertEqual(len(df.columns), 7) | ||
self.assertValuesEqual(df.dissimilarity, dissimilarities.ravel()) | ||
self.assertValuesEqual(df['rdm_index'], ([0]*6) + ([1]*6)) | ||
self.assertValuesEqual(df['xy'], (['x']*6) + (['y']*6)) | ||
self.assertValuesEqual(df['pattern_index_1'], | ||
([0]*3 + [1]*2 + [2]*1)*2) | ||
self.assertValuesEqual(df['pattern_index_2'], [1, 2, 3, 2, 3, 3]*2) | ||
self.assertValuesEqual(df['abcd_1'], [c for c in 'aaabbc']*2) | ||
self.assertValuesEqual(df['abcd_2'], [c for c in 'bcdcdd']*2) |