From 9c7e1dc14f4b78dd6f244b83a4c282e96b4eecc0 Mon Sep 17 00:00:00 2001 From: chrystal chern <52893467+chrystalchern@users.noreply.github.com> Date: Tue, 11 Jun 2024 15:49:15 -0700 Subject: [PATCH] Update printing.py --- src/mdof/utilities/printing.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mdof/utilities/printing.py b/src/mdof/utilities/printing.py index db035dc..ec8d4c5 100644 --- a/src/mdof/utilities/printing.py +++ b/src/mdof/utilities/printing.py @@ -111,7 +111,9 @@ def plot_models(models, Tn, zeta): def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitles=(None,None), **options): - fig, ax = plt.subplots(1,2,figsize=options.get('figsize',(10,3)),constrained_layout=True,sharey=options.get('sharey',(ylabels[0]==ylabels[1]))) + fig, ax = options.get('figax', + plt.subplots(1,2,figsize=options.get('figsize',(10,3)),constrained_layout=True,sharey=options.get('sharey',(ylabels[0]==ylabels[1]))) + ) if len(inputs.shape) > 1: for i in range(inputs.shape[0]): ax[0].plot(t,inputs[i,:],label=f"input {i+1}") @@ -137,7 +139,7 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs", **options): linestyles = ['dashed', 'dashdot', 'dotted'] colors = ['blue', 'orange', 'green', 'magenta'] - fig, ax = plt.subplots(figsize=options.get('figsize',(6,3))) + fig, ax = options.get('figax',plt.subplots(figsize=options.get('figsize',(6,3)))) if len(ytrue.shape) > 1: for i in range(ytrue.shape[0]): ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',linestyle=linestyles[i%len(linestyles)]) @@ -147,9 +149,9 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs", **options): if len(models.shape) > 1: for i in range(models.shape[0]): ax.plot(t,models[i,:],label=f"prediction, DOF {i+1}" if models.shape[0]>1 else f"{options.get('single_label','prediction')}", - linestyle=linestyles[i%len(linestyles)],linewidth=2,color=colors[i%len(colors)],alpha=0.5) + linestyle=linestyles[i%len(linestyles)],linewidth=2,color=options.get('single_color',colors[i%len(colors)]),alpha=0.5) else: - ax.plot(t,models,"--",label=f"{options.get('single_label','prediction')}") + ax.plot(t,models,"--",color=options.get('single_color',None),label=f"{options.get('single_label','prediction')}") else: for k,method in enumerate(models): if len(models[method]["ypred"].shape) > 1: @@ -192,7 +194,7 @@ def plot_transfer(models, title=None, labels=None, plotly=False, **options): fig.show(renderer="notebook_connected") else: linestyles = ['-','--',':'] - fig, ax = plt.subplots(figsize=(6,3)) + fig, ax = plt.subplots(figsize=options.get('figsize',(6,3))) if type(models) is np.ndarray: if len(models.shape) > 2: for i in range(models.shape[0]):