diff --git a/src/dynamic_routing_analysis/plot_utils.py b/src/dynamic_routing_analysis/plot_utils.py index 4e8895d..553dba4 100644 --- a/src/dynamic_routing_analysis/plot_utils.py +++ b/src/dynamic_routing_analysis/plot_utils.py @@ -927,8 +927,12 @@ def plot_brain_heatmap( # clean up inputs regions = np.array(regions) values = np.array(values) + if values.shape[0] == 2 and values.shape[1] != 2: + values = values.T if sagittal_planes is None: sagittal_planes = [] + elif not isinstance(sagittal_planes, Iterable): + sagittal_planes = (sagittal_planes, ) else: sagittal_planes = tuple(sagittal_planes) # type: ignore if clevels is None: @@ -937,7 +941,6 @@ def plot_brain_heatmap( clevels = tuple(clevels) # type: ignore if len(clevels) != 2: raise ValueError("clevels must be a sequence of length 2") - # set up kwargs that are shared between all axes joint_kwargs = { 'regions': regions,