Skip to content

Commit

Permalink
Add ability to have grouped angles
Browse files Browse the repository at this point in the history
  • Loading branch information
MitchellAcoustics committed Nov 14, 2023
1 parent 6c2d6f9 commit 209687e
Show file tree
Hide file tree
Showing 5 changed files with 473 additions and 238 deletions.
141 changes: 141 additions & 0 deletions docs/Introduction to SSM Analysis.ipynb

Large diffs are not rendered by default.

120 changes: 119 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ test = [
]
docs = [
"jupyter>=1.0.0",
"mkdocs>=1.5.3",
]
dev = [
"setuptools>=68.2.2",
Expand Down
48 changes: 37 additions & 11 deletions src/circumplex/circumplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cycler import cycler
import numpy as np
import pandas as pd
import math
from scipy.optimize import curve_fit

OCTANTS = (0, 45, 90, 135, 180, 225, 270, 315)
Expand Down Expand Up @@ -95,7 +96,7 @@ def __str__(self):
# TODO: Add param results
return f"{self.label}: {self.params})"

def profile_plot(self) -> tuple:
def profile_plot(self, ax=None) -> tuple:
"""
Plot the SSM profile.
Expand All @@ -110,6 +111,7 @@ def profile_plot(self) -> tuple:
self.angles,
self.scores,
self.label,
ax=ax,
)

def plot(self):
Expand Down Expand Up @@ -206,16 +208,23 @@ def plot(self, colors=None, legend=True, *args, **kwargs) -> tuple:
fig.legend(loc="upper right", bbox_to_anchor=(1.2, 1))
return fig, ax

def profile_plots(self) -> None:
def profile_plots(self, axes=None) -> None:
"""
Plot the SSM profiles.
Returns:
None
"""
for res in self.results:
fig, ax = res.profile_plot()
plt.show()
if axes is None:
fig, axes = plt.subplots(
nrows=len(self.results),
figsize=(8, 4 * len(self.results)),
)
for i, res in enumerate(self.results):
fig, ax = res.profile_plot(ax=axes.flatten()[i])
plt.tight_layout()
# plt.show()
return fig, axes


# %%
Expand All @@ -227,6 +236,7 @@ def ssm_analyse(
measures: list | None = None,
grouping: list | None = None,
angles: tuple = OCTANTS,
grouped_angles: dict = None,
) -> SSMResults:
"""
Analyse a set of data using the SSM method.
Expand All @@ -245,11 +255,13 @@ def ssm_analyse(
"""
if grouping is not None and measures is not None:
return ssm_analyse_grouped_corrs(data, scales, measures, grouping, angles)
return ssm_analyse_grouped_corrs(
data, scales, measures, grouping, angles, grouped_angles
)
elif measures is not None:
return ssm_analyse_corrs(data, scales, measures, angles)
elif grouping is not None:
return ssm_analyse_means(data, scales, grouping, angles)
return ssm_analyse_means(data, scales, grouping, angles, grouped_angles)
else:
ssm = SSMParams(data[scales].mean(), scales, angles)
# ssm.param_calc()
Expand All @@ -262,6 +274,7 @@ def ssm_analyse_grouped_corrs(
measures: list,
grouping: list,
angles: tuple = OCTANTS,
grouped_angles: dict = None,
) -> SSMResults:
"""
Perform SSM analysis of correlations for a set of grouped data.
Expand All @@ -281,6 +294,8 @@ def ssm_analyse_grouped_corrs(
res = []
for group_var in grouping:
for group, group_data in data.groupby(group_var):
if grouped_angles is not None:
angles = grouped_angles[group] # grouped angles will override angles
try:
res.append(
ssm_analyse_corrs(
Expand Down Expand Up @@ -326,7 +341,11 @@ def ssm_analyse_corrs(


def ssm_analyse_means(
data: pd.DataFrame, scales: tuple, grouping: list, angles: tuple = OCTANTS
data: pd.DataFrame,
scales: tuple,
grouping: list,
angles: tuple = OCTANTS,
grouped_angles: dict = None,
) -> SSMResults:
"""
Perform SSM analysis of means for a set of data.
Expand All @@ -345,6 +364,8 @@ def ssm_analyse_means(
means = data.groupby(grouping)[scales].mean()
res = []
for group, scores in means.iterrows():
if grouped_angles is not None:
angles = grouped_angles[group]
scores = means.loc[group]
ssm = SSMParams(scores, scales, angles, group=group)
# ssm.param_calc()
Expand Down Expand Up @@ -393,7 +414,6 @@ def ssm_parameters(scores, angles, bounds=BOUNDS):
)
r2 = _r2_score(scores, cosine_form(angles, *param))
ampl, disp, elev = param
disp = disp - 360 if disp > 360 else disp

def polar2cart(r, theta):
x = r * np.cos(theta)
Expand All @@ -404,7 +424,9 @@ def polar2cart(r, theta):
return elev, xval, yval, ampl, disp, r2


def profile_plot(amplitude, displacement, elevation, r2, angles, scores, label):
def profile_plot(
amplitude, displacement, elevation, r2, angles, scores, label, ax=None
):
"""
Plot the SSM profile.
Expand All @@ -414,7 +436,11 @@ def profile_plot(amplitude, displacement, elevation, r2, angles, scores, label):
thetas = np.linspace(0, 360, 1000)
fit = cosine_form(thetas, amplitude, displacement, elevation)

fig, ax = plt.subplots()
if ax is None:
fig, ax = plt.subplots(figsize=(8, 4))
else:
fig = ax.get_figure()

ax.plot(thetas, fit, color="black")
ax.plot(angles, scores, color="red", marker="o")
# ax.scatter(self.angles, self.scores, marker="o", color="black")
Expand Down
401 changes: 175 additions & 226 deletions tests/Intro_SSM.ipynb

Large diffs are not rendered by default.

0 comments on commit 209687e

Please sign in to comment.