Skip to content

Commit

Permalink
Update printing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chrystalchern committed Jun 8, 2024
1 parent de24ec8 commit 8d74a5e
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions src/mdof/utilities/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
color_iter = cycle(DEFAULT_PLOTLY_COLORS)


def print_modes(modes, Tn=None, zeta=None):
def print_modes(modes, Tn=None, zeta=None, sigfigs=4):

if len(modes) == 0:
print("No valid identified modes.")
Expand All @@ -50,11 +50,11 @@ def print_modes(modes, Tn=None, zeta=None):
z = mode["damp"]
emaco = mode["energy_condensed_emaco"]
mpc = mode["mpc"]
row = f" {1/f: <9.4} {z: <9.4} {emaco: <9.4} {mpc: <9.4} {emaco*mpc: <9.4}"
row = f" {1/f: <9.{sigfigs}} {z: <9.{sigfigs}} {emaco: <9.{sigfigs}} {mpc: <9.{sigfigs}} {emaco*mpc: <9.{sigfigs}}"
if Tn is not None:
row += f" {100*(1/f-Tn)/(Tn): <9.4}"
row += f" {100*(1/f-Tn)/(Tn): <9.{sigfigs}}"
if zeta is not None:
row += f" {100*(z-zeta)/zeta: <9.4}"
row += f" {100*(z-zeta)/zeta: <9.{sigfigs}}"
print(row)
print("Mean Period(s):", np.mean([1/v["freq"] for v in modes.values()]))
print("Standard Dev(s):", np.std([1/v["freq"] for v in modes.values()]))
Expand Down Expand Up @@ -110,8 +110,8 @@ def plot_models(models, Tn, zeta):
# fig.suptitle("Spectral Quantity Prediction with System Identification",fontsize=17)


def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitles=(None,None)):
fig, ax = plt.subplots(1,2,figsize=(10,3),constrained_layout=True,sharey=(ylabels[0]==ylabels[1]))
def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitles=(None,None), **options):
fig, ax = plt.subplots(1,2,figsize=options.get('figsize',(10,3)),constrained_layout=True,sharey=options.get('sharey',(ylabels[0]==ylabels[1])))
if len(inputs.shape) > 1:
for i in range(inputs.shape[0]):
ax[0].plot(t,inputs[i,:],label=f"input {i+1}")
Expand All @@ -130,13 +130,14 @@ def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitl
ax[1].set_ylabel(ylabels[1], fontsize=15)
ax[1].set_title(axtitles[1], fontsize=15)
fig.suptitle(title, fontsize=17)
return fig


def plot_pred(ytrue, models, t, title=None, ylabel="outputs"):
def plot_pred(ytrue, models, t, title=None, ylabel="outputs", **options):
linestyles = ['dashed', 'dashdot', 'dotted']
colors = ['blue', 'orange', 'green', 'magenta']

fig, ax = plt.subplots(figsize=(6,3))
fig, ax = plt.subplots(figsize=options.get('figsize',(6,3)))
if len(ytrue.shape) > 1:
for i in range(ytrue.shape[0]):
ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',linestyle=linestyles[i%len(linestyles)])
Expand All @@ -145,9 +146,10 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs"):
if type(models) is np.ndarray:
if len(models.shape) > 1:
for i in range(models.shape[0]):
ax.plot(t,models[i,:],label=f"prediction, DOF {i+1}" if models.shape[0]>1 else "predicition",linestyle=linestyles[i%len(linestyles)],linewidth=2,color=colors[i%len(colors)],alpha=0.5)
ax.plot(t,models[i,:],label=f"prediction, DOF {i+1}" if models.shape[0]>1 else f"{options.get('single_label','prediction')}",
linestyle=linestyles[i%len(linestyles)],linewidth=2,color=colors[i%len(colors)],alpha=0.5)
else:
ax.plot(t,models,"--",label=f"prediction")
ax.plot(t,models,"--",label=f"{options.get('single_label','prediction')}")
else:
for k,method in enumerate(models):
if len(models[method]["ypred"].shape) > 1:
Expand All @@ -161,9 +163,10 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs"):
ax.set_ylabel(ylabel, fontsize=14)
fig.legend(fontsize=12, frameon=True, framealpha=0.4, bbox_to_anchor=(0.9,0,0.5,0.8), loc='upper left')
fig.suptitle(title, fontsize=14)
return fig


def plot_transfer(models, title=None, labels=None, plotly=False):
def plot_transfer(models, title=None, labels=None, plotly=False, **options):
if plotly:
import plotly.graph_objects as go
layout = go.Layout(
Expand All @@ -172,37 +175,39 @@ def plot_transfer(models, title=None, labels=None, plotly=False):
title="Period (s)"
),
yaxis=dict(
title="Amplitude"
title=options.get('ylabel', "Amplitude")
),
width=600, height=300,
margin=dict(l=70, r=20, t=20, b=20))
fig = go.Figure(layout=layout)
if type(models) is np.ndarray:
if len(models.shape) > 2:
for i in range(models.shape[0]):
fig.add_trace(go.Scatter(x=models[i,0],y=models[i,1]/max(models[i,1]),name=labels[i]))
fig.add_trace(go.Scatter(x=models[i,0],y=models[i,1]/max(models[i,1]),name=labels[i] if labels is not None else None))
else:
fig.add_trace(go.Scatter(x=models[0],y=models[1]/max(models[1]),name=labels))
else:
for method in models:
fig.add_trace(go.Scatter(x=models[method][0],y=models[method][1]/max(models[method][1]),name=method))
fig.show(renderer="notebook_connected")
else:
linestyles = ['-','--',':']
fig, ax = plt.subplots(figsize=(6,3))
if type(models) is np.ndarray:
if len(models.shape) > 2:
for i in range(models.shape[0]):
ax.plot(models[i,0],models[i,1]/max(models[i,1]),label=labels[i])
ax.plot(models[i,0],models[i,1]/max(models[i,1]),linestyle=linestyles[i],label=labels[i] if labels is not None else None)
else:
ax.plot(models[0],models[1]/max(models[1]),label=labels)
else:
for method in models:
ax.plot(models[method][0],models[method][1]/max(models[method][1]),label=method)
ax.set_xlabel("Period (s)")
ax.set_ylabel("Amplitude")
if labels is not None:
ax.legend()#fontsize=12)
ax.set_title(title)#, fontsize=14)
for i,method in enumerate(models):
ax.plot(models[method][0],models[method][1]/max(models[method][1]),linestyle=linestyles[i],label=method)
ax.set_xlabel("Period (s)",fontsize=14)
ax.set_ylabel(options.get('ylabel', "Amplitude"),fontsize=14)
if (labels is not None) or (not type(models) is np.ndarray):
ax.legend(fontsize=12, frameon=True, framealpha=0.4, bbox_to_anchor=(1,0,0.5,0.8), loc='upper left')
ax.set_title(title)
return fig


class FrequencyContent:
Expand Down

0 comments on commit 8d74a5e

Please sign in to comment.