Skip to content

Commit

Permalink
plots printing utils
Browse files Browse the repository at this point in the history
  • Loading branch information
chrystalchern committed Oct 11, 2024
1 parent 30e4a3d commit 717e1fd
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/mdof/utilities/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,16 @@ def plot_io(inputs, outputs, t, title=None, xlabels=("time (s)", "time (s)"), yl
def plot_pred(ytrue, models, t, title=None, xlabel="time (s)", ylabel="outputs", makelegend=True, **options):
linestyles = ['dashed', 'dashdot', 'dotted']
colors = ['blue', 'orange', 'green', 'magenta']
true_first = options['true_first'] if 'true_first' in options.keys() else False

# fig, ax = options.get('figax',plt.subplots(figsize=options.get('figsize',(6,3))))
fig, ax = options['figax'] if 'figax' in options.keys() else plt.subplots(figsize=options.get('figsize',(6,3)))
if len(ytrue.shape) > 1:
for i in range(ytrue.shape[0]):
ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',linestyle=linestyles[i%len(linestyles)])
else:
ax.plot(t,ytrue,label="true",color='black')
if true_first:
if len(ytrue.shape) > 1:
for i in range(ytrue.shape[0]):
ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',alpha=0.5,linestyle=linestyles[i%len(linestyles)])
else:
ax.plot(t,ytrue,label="true",color='black',alpha=0.5)
if type(models) is np.ndarray:
if len(models.shape) > 1:
for i in range(models.shape[0]):
Expand All @@ -158,6 +160,12 @@ def plot_pred(ytrue, models, t, title=None, xlabel="time (s)", ylabel="outputs",
label=f"{method.upper()}, DOF {i+1}" if models[method]["ypred"].shape[0]>1 else f"{method.upper()}")
else:
ax.plot(t,models[method]["ypred"],linestyle=linestyles[k%len(linestyles)],linewidth=2,color=colors[k],alpha=0.5,label=method.upper())
if not true_first:
if len(ytrue.shape) > 1:
for i in range(ytrue.shape[0]):
ax.plot(t,ytrue[i,:],label=f"true, DOF {i+1}",color='black',alpha=0.5,linestyle=linestyles[i%len(linestyles)])
else:
ax.plot(t,ytrue,label="true",color='black',alpha=0.5)
ax.set_xlabel(xlabel, fontsize=14)
ax.set_ylabel(ylabel, fontsize=14)
if makelegend:
Expand Down

0 comments on commit 717e1fd

Please sign in to comment.