Skip to content

Commit

Permalink
Update printing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chrystalchern committed Jun 11, 2024
1 parent 8d74a5e commit 9c7e1dc
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/mdof/utilities/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def plot_models(models, Tn, zeta):


def plot_io(inputs, outputs, t, title=None, ylabels=("inputs","outputs"), axtitles=(None,None), **options):
fig, ax = plt.subplots(1,2,figsize=options.get('figsize',(10,3)),constrained_layout=True,sharey=options.get('sharey',(ylabels[0]==ylabels[1])))
fig, ax = options.get('figax',
plt.subplots(1,2,figsize=options.get('figsize',(10,3)),constrained_layout=True,sharey=options.get('sharey',(ylabels[0]==ylabels[1])))
)
if len(inputs.shape) > 1:
for i in range(inputs.shape[0]):
ax[0].plot(t,inputs[i,:],label=f"input {i+1}")
Expand All @@ -137,7 +139,7 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs", **options):
linestyles = ['dashed', 'dashdot', 'dotted']
colors = ['blue', 'orange', 'green', 'magenta']

fig, ax = plt.subplots(figsize=options.get('figsize',(6,3)))
fig, ax = options.get('figax',plt.subplots(figsize=options.get('figsize',(6,3))))
if len(ytrue.shape) > 1:
for i in range(ytrue.shape[0]):
ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',linestyle=linestyles[i%len(linestyles)])
Expand All @@ -147,9 +149,9 @@ def plot_pred(ytrue, models, t, title=None, ylabel="outputs", **options):
if len(models.shape) > 1:
for i in range(models.shape[0]):
ax.plot(t,models[i,:],label=f"prediction, DOF {i+1}" if models.shape[0]>1 else f"{options.get('single_label','prediction')}",
linestyle=linestyles[i%len(linestyles)],linewidth=2,color=colors[i%len(colors)],alpha=0.5)
linestyle=linestyles[i%len(linestyles)],linewidth=2,color=options.get('single_color',colors[i%len(colors)]),alpha=0.5)
else:
ax.plot(t,models,"--",label=f"{options.get('single_label','prediction')}")
ax.plot(t,models,"--",color=options.get('single_color',None),label=f"{options.get('single_label','prediction')}")
else:
for k,method in enumerate(models):
if len(models[method]["ypred"].shape) > 1:
Expand Down Expand Up @@ -192,7 +194,7 @@ def plot_transfer(models, title=None, labels=None, plotly=False, **options):
fig.show(renderer="notebook_connected")
else:
linestyles = ['-','--',':']
fig, ax = plt.subplots(figsize=(6,3))
fig, ax = plt.subplots(figsize=options.get('figsize',(6,3)))
if type(models) is np.ndarray:
if len(models.shape) > 2:
for i in range(models.shape[0]):
Expand Down

0 comments on commit 9c7e1dc

Please sign in to comment.