From 8d74a5e8cb94fdccefcea9c290d6d1b9a75afe9e Mon Sep 17 00:00:00 2001 From: chrystal chern <52893467+chrystalchern@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:05:07 -0700 Subject: [PATCH] Update printing.py --- src/mdof/utilities/printing.py | 47 +++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/mdof/utilities/printing.py b/src/mdof/utilities/printing.py index 195d15e..db035dc 100644 --- a/src/mdof/utilities/printing.py +++ b/src/mdof/utilities/printing.py @@ -29,7 +29,7 @@ color_iter = cycle(DEFAULT_PLOTLY_COLORS) -def print_modes(modes, Tn=None, zeta=None): +def print_modes(modes, Tn=None, zeta=None, sigfigs=4): if len(modes) == 0: print("No valid identified modes.") @@ -50,11 +50,11 @@ def print_modes(modes, Tn=None, zeta=None): z = mode["damp"] emaco = mode["energy_condensed_emaco"] mpc = mode["mpc"] - row = f" {1/f: <9.4} {z: <9.4} {emaco: <9.4} {mpc: <9.4} {emaco*mpc: <9.4}" + row = f" {1/f: <9.{sigfigs}} {z: <9.{sigfigs}} {emaco: <9.{sigfigs}} {mpc: <9.{sigfigs}} {emaco*mpc: <9.{sigfigs}}" if Tn is not None: - row += f" {100*(1/f-Tn)/(Tn): <9.4}" + row += f" {100*(1/f-Tn)/(Tn): <9.{sigfigs}}" if zeta is not None: - row += f" {100*(z-zeta)/zeta: <9.4}" + row += f" {100*(z-zeta)/zeta: <9.{sigfigs}}" print(row) print("Mean Period(s):", np.mean([1/v["freq"] for v in modes.values()])) print("Standard Dev(s):", np.std([1/v["freq"] for v in modes.values()])) @@ -110,8 +110,8 @@ def plot_models(models, Tn, zeta): # fig.suptitle("Spectral Quantity Prediction with System Identification",fontsize=17) -def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitles=(None,None)): - fig, ax = plt.subplots(1,2,figsize=(10,3),constrained_layout=True,sharey=(ylabels[0]==ylabels[1])) +def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitles=(None,None), **options): + fig, ax = plt.subplots(1,2,figsize=options.get('figsize',(10,3)),constrained_layout=True,sharey=options.get('sharey',(ylabels[0]==ylabels[1]))) if len(inputs.shape) > 1: for i in range(inputs.shape[0]): ax[0].plot(t,inputs[i,:],label=f"input {i+1}") @@ -130,13 +130,14 @@ def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitl ax[1].set_ylabel(ylabels[1], fontsize=15) ax[1].set_title(axtitles[1], fontsize=15) fig.suptitle(title, fontsize=17) + return fig -def plot_pred(ytrue, models, t, title=None, ylabel="outputs"): +def plot_pred(ytrue, models, t, title=None, ylabel="outputs", **options): linestyles = ['dashed', 'dashdot', 'dotted'] colors = ['blue', 'orange', 'green', 'magenta'] - fig, ax = plt.subplots(figsize=(6,3)) + fig, ax = plt.subplots(figsize=options.get('figsize',(6,3))) if len(ytrue.shape) > 1: for i in range(ytrue.shape[0]): ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',linestyle=linestyles[i%len(linestyles)]) @@ -145,9 +146,10 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs"): if type(models) is np.ndarray: if len(models.shape) > 1: for i in range(models.shape[0]): - ax.plot(t,models[i,:],label=f"prediction, DOF {i+1}" if models.shape[0]>1 else "predicition",linestyle=linestyles[i%len(linestyles)],linewidth=2,color=colors[i%len(colors)],alpha=0.5) + ax.plot(t,models[i,:],label=f"prediction, DOF {i+1}" if models.shape[0]>1 else f"{options.get('single_label','prediction')}", + linestyle=linestyles[i%len(linestyles)],linewidth=2,color=colors[i%len(colors)],alpha=0.5) else: - ax.plot(t,models,"--",label=f"prediction") + ax.plot(t,models,"--",label=f"{options.get('single_label','prediction')}") else: for k,method in enumerate(models): if len(models[method]["ypred"].shape) > 1: @@ -161,9 +163,10 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs"): ax.set_ylabel(ylabel, fontsize=14) fig.legend(fontsize=12, frameon=True, framealpha=0.4, bbox_to_anchor=(0.9,0,0.5,0.8), loc='upper left') fig.suptitle(title, fontsize=14) + return fig -def plot_transfer(models, title=None, labels=None, plotly=False): +def plot_transfer(models, title=None, labels=None, plotly=False, **options): if plotly: import plotly.graph_objects as go layout = go.Layout( @@ -172,7 +175,7 @@ def plot_transfer(models, title=None, labels=None, plotly=False): title="Period (s)" ), yaxis=dict( - title="Amplitude" + title=options.get('ylabel', "Amplitude") ), width=600, height=300, margin=dict(l=70, r=20, t=20, b=20)) @@ -180,7 +183,7 @@ def plot_transfer(models, title=None, labels=None, plotly=False): if type(models) is np.ndarray: if len(models.shape) > 2: for i in range(models.shape[0]): - fig.add_trace(go.Scatter(x=models[i,0],y=models[i,1]/max(models[i,1]),name=labels[i])) + fig.add_trace(go.Scatter(x=models[i,0],y=models[i,1]/max(models[i,1]),name=labels[i] if labels is not None else None)) else: fig.add_trace(go.Scatter(x=models[0],y=models[1]/max(models[1]),name=labels)) else: @@ -188,21 +191,23 @@ def plot_transfer(models, title=None, labels=None, plotly=False): fig.add_trace(go.Scatter(x=models[method][0],y=models[method][1]/max(models[method][1]),name=method)) fig.show(renderer="notebook_connected") else: + linestyles = ['-','--',':'] fig, ax = plt.subplots(figsize=(6,3)) if type(models) is np.ndarray: if len(models.shape) > 2: for i in range(models.shape[0]): - ax.plot(models[i,0],models[i,1]/max(models[i,1]),label=labels[i]) + ax.plot(models[i,0],models[i,1]/max(models[i,1]),linestyle=linestyles[i],label=labels[i] if labels is not None else None) else: ax.plot(models[0],models[1]/max(models[1]),label=labels) else: - for method in models: - ax.plot(models[method][0],models[method][1]/max(models[method][1]),label=method) - ax.set_xlabel("Period (s)") - ax.set_ylabel("Amplitude") - if labels is not None: - ax.legend()#fontsize=12) - ax.set_title(title)#, fontsize=14) + for i,method in enumerate(models): + ax.plot(models[method][0],models[method][1]/max(models[method][1]),linestyle=linestyles[i],label=method) + ax.set_xlabel("Period (s)",fontsize=14) + ax.set_ylabel(options.get('ylabel', "Amplitude"),fontsize=14) + if (labels is not None) or (not type(models) is np.ndarray): + ax.legend(fontsize=12, frameon=True, framealpha=0.4, bbox_to_anchor=(1,0,0.5,0.8), loc='upper left') + ax.set_title(title) + return fig class FrequencyContent: