from jax.config import config; config.update("jax_enable_x64", True)
# Align the y-axis labels of the first column of subplots so they line up
# vertically (useful when the rows have tick labels of different widths).
# Assumes `ax` is a 2-D array of Axes, as the ax[:, 0] indexing implies.
fig.align_ylabels(ax[:, 0])
def float_to_str(x, precision=1):
    r"""Format a number as a LaTeX scientific-notation string.

    ``float_to_str(3e-5)`` returns ``"${3.0}\times10^{-5}$"``.

    Args:
        x: A real number (anything accepted by ``format(x, ".1e")``).
        precision: Digits after the decimal point in the coefficient.
            Defaults to 1, matching the original fixed ``.1e`` format.

    Returns:
        A math-mode LaTeX string. Non-finite values have no ``e``-exponent
        to split, so they map to ``$\infty$``, ``$-\infty$`` or
        ``$\mathrm{NaN}$`` instead of raising ``ValueError``.
    """
    import math  # local import keeps this snippet self-contained

    if math.isnan(x):
        return r"$\mathrm{NaN}$"
    if math.isinf(x):
        return r"$\infty$" if x > 0 else r"$-\infty$"

    number_sci = f"{x:.{precision}e}"  # e.g. "3.0e-05"
    base, _, exponent = number_sci.partition("e")
    # int() drops the explicit '+' sign and leading zeros ("-05" -> -5, "+04" -> 4).
    return r"${{{}}}\times10^{{{}}}$".format(base, int(exponent))
Some common, easy-to-forget code snippets I use when making plots.
As always, if you are in a notebook, start with
%config InlineBackend.figure_format='retina'
to make your plots sharper and more eye-friendly (inline figures are rendered at higher resolution for high-DPI displays).
from scipy.ndimage import uniform_filter1d
def smooth(x, size=10):
    """Smooth ``x`` with a moving average over ``size`` neighbouring samples.

    Thin wrapper around ``scipy.ndimage.uniform_filter1d`` (its defaults apply:
    last axis, ``mode='reflect'`` at the edges, output dtype follows the input).

    Args:
        x: Array-like signal to smooth.
        size: Width of the averaging window. Defaults to 10, the window this
            helper originally hard-coded; raise it to smooth more.

    Returns:
        The smoothed array, same shape as ``x``.
    """
    return uniform_filter1d(x, size)
# NOTE(review): this loop header has no visible body in these notes; the
# statement(s) it was meant to wrap appear to be missing, and as written it is
# a syntax error. Confirm against the original snippet.
for a in ax:
# Place the legend outside the axes: the midpoint of the legend's left edge
# ('center left') is anchored at (1, 0.5) in axes coordinates, i.e. just to the
# right of the plot and vertically centered.
legend = ax[0].legend(loc='center left', bbox_to_anchor=(1, 0.5))
# NOTE(review): calling ax[0].legend() again replaces the legend created on the
# line above, so these two lines look like alternatives, not consecutive steps.
legend1 = ax[0].legend(ncol=3) # Legend you got from other plots
# Export the legend of an existing plot (`legend1`) as its own PDF, so it can be
# placed once above a grid of subfigures in LaTeX.
# Create a new figure just for the legend
fig_leg = plt.figure(figsize=(2, 0.5))  # You can adjust the size as needed
ax_leg = fig_leg.add_subplot(111)
label_order = [0, 1, 2, 3]  # If you want to switch the order of the legend entries
# Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles` (the old
# name was later removed), so prefer the new name and fall back to the old one.
handles = getattr(legend1, "legend_handles", None)
if handles is None:
    handles = legend1.legendHandles
leg = ax_leg.legend(
    handles=[handles[i] for i in label_order],
    # You could also manually write the names for the labels
    labels=[legend1.get_texts()[i].get_text() for i in label_order],
    ncol=4, loc='center',
)
# Adjust the linewidth if you want
for line in leg.get_lines():
    line.set_linewidth(2)
# Turn off the axis so only the legend box is drawn
ax_leg.axis('off')
fig_leg.savefig('./my_legends.pdf', bbox_inches='tight', pad_inches=0.01)
Then in LaTeX, you can simply do:
\end{subfigure} \\[-3.8ex]
# A 3x6 grid of subplots sharing the x axis within each column.
fig, ax = plt.subplots(3, 6, figsize=(6.75, 3), sharex='col')
# Margins as fractions of the figure size, passed to subplots_adjust below.
bottom = 0.1
top = 0.85
left = 0.075
right = 0.98
fig.subplots_adjust(wspace=0.6, hspace=0.4, left=left, right=right, bottom=bottom, top=top)
# `ax` is a 3x6 array of Axes here, so iterate over `ax.flat` to reach every
# subplot. A box aspect of 1/1.618 makes each axes box's height/width equal
# the golden ratio's inverse (wider than tall).
for a in ax.flat:
    a.set_box_aspect(aspect=1 / 1.618)
def plot_arrows(X, Y, ax, c, alphas=(0.2, 0.6, 1.0)):
    """Draw one arrow per row of ``X``/``Y`` on ``ax`` with per-arrow opacity.

    Arrow ``i`` goes from ``(X[i, 0], Y[i, 0])`` (tail) to ``(X[i, 1], Y[i, 1])``
    (head).

    Args:
        X, Y: 2-D arrays indexed as ``X[i, j]`` (e.g. NumPy arrays); column 0
            holds the tail and column 1 the head of each arrow.
        ax: Matplotlib Axes (anything with an ``annotate`` method).
        c: Color used for the arrow face and edge.
        alphas: Opacity of each arrow; needs at least ``len(X)`` entries.
            A tuple default avoids the shared-mutable-default pitfall.
    """
    arrowstyle = "simple,head_width=0.5,head_length=0.6"
    for i in range(len(X)):
        alpha = alphas[i]
        prop = dict(arrowstyle=arrowstyle, shrinkA=0, shrinkB=0,
                    facecolor=c, edgecolor=c, alpha=alpha, lw=3)
        # Empty annotation text: only the arrow from xytext (tail) to xy (head) is drawn.
        ax.annotate("", xy=(X[i, 1], Y[i, 1]), xytext=(X[i, 0], Y[i, 0]),
                    arrowprops=prop, color=c, alpha=alpha)
To use scientific notation for the y-axis tick labels, showing only the coefficient on the axis and the exponent at the top:
import matplotlib.ticker as mtick
# useMathText=True renders the exponent as math text, e.g. "1×10^-5";
# set it to False for the plain "1e-5" style.
formatter = mtick.ScalarFormatter(useMathText=True)
# NOTE(review): the statement this comment referred to is not shown here,
# presumably formatter.set_powerlimits((-1, 1)). Per the Matplotlib docs,
# set_powerlimits((min_exp, max_exp)) switches to scientific notation when the
# tick values' power of ten is <= min_exp or >= max_exp; with (-1, 1) that is
# every magnitude outside [1, 10).
# Change the position and font size of the exponent
# NOTE(review): the code for this step is also not shown here.
# Replace the x tick labels of the first subplot with custom strings.
# NOTE(review): Matplotlib warns when set_ticklabels() is used without a fixed
# set of ticks; set the matching positions first (e.g. via set_xticks) so the
# three labels stay attached to the intended ticks. Confirm the positions.
ax[0].xaxis.set_ticklabels([0, '25K', '50K'])
# Shrink the x tick label font on the first subplot
ax[0].tick_params(axis='x', labelsize=8)
# or, for both axes of a single Axes object:
ax.tick_params(axis='both', labelsize=8)
# labelpad adds extra space (in points) between the y label and the tick labels
ax.set_ylabel('ELBO', fontsize=18, labelpad=12) # The labelpad does the job
A good website for picking colors from the same "family"; e.g., you could use one color family for all variations of the same method.
Or, if you want to extract colors from a Matplotlib colormap, you could do:
from matplotlib import cm
# Calling a colormap with integers indexes its color table and returns RGBA rows.
# For the 10-color qualitative map tab10, this picks colors 0, 1, 2 and 4.
my_colors = cm.tab10([0,1,2,4])