Skip to content

Commit

Permalink
Avoiding interference between the colormap and axis labels
Browse files Browse the repository at this point in the history
  • Loading branch information
oMuransky committed Jul 10, 2024
1 parent 78d9152 commit a9dd88b
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions neml/cp/polefigures.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ def pol2cart(R, T):

return X, Y

def inverse_pole_figure_discrete(orientations, direction, lattice,
reduce_figure = False, color = None,
sample_symmetry = crystallography.symmetry_rotations("222"),
x = [1,0,0], y = [0,1,0], axis_labels = None, nline = 100):
def inverse_pole_figure_discrete(orientations, direction, lattice,
reduce_figure=False, color=None,
sample_symmetry=crystallography.symmetry_rotations(
"222"),
x=[1, 0, 0], y=[0, 1, 0], axis_labels=None, nline=100):
"""
Plot an inverse pole figure given a collection of discrete points.
Expand All @@ -171,18 +172,19 @@ def inverse_pole_figure_discrete(orientations, direction, lattice,
axis_labels: axis labels to include on the figure
nline: number of discrete points to use in plotting lines on the triangle
"""
pts = np.vstack(tuple(project_ipf(q, lattice, direction,
sample_symmetry = sample_symmetry, x = x, y = y) for q in orientations))
pts = np.vstack(tuple(project_ipf(q, lattice, direction,
sample_symmetry=sample_symmetry, x=x, y=y) for q in orientations))

if reduce_figure:
if reduce_figure == "cubic":
vs = (np.array([0,0,1.0]), np.array([1.0,0,1]), np.array([1.0,1,1]))
vs = (np.array([0, 0, 1.0]), np.array(
[1.0, 0, 1]), np.array([1.0, 1, 1]))
elif len(reduce_figure) == 3:
vs = reduce_figure
else:
raise ValueError("Unknown reduction type %s!" % reduce_figure)
pts = reduce_points_triangle(pts, v0 = vs[0], v1=vs[1], v2=vs[2])

pts = reduce_points_triangle(pts, v0=vs[0], v1=vs[1], v2=vs[2])

pop = project_stereographic
lim = limit_stereographic
Expand All @@ -197,32 +199,43 @@ def inverse_pole_figure_discrete(orientations, direction, lattice,
full_color = np.zeros((len(cpoints),))
full_color[::2] = color
full_color[1::2] = color
sc = ax.scatter(cpoints[:,0], cpoints[:,1], c=full_color, s = 10.0)
plt.colorbar(sc)
sc = ax.scatter(cpoints[:, 0], cpoints[:, 1], c=full_color, s=10.0)
# plt.colorbar(sc)
# Add a horizontal colorbar at the bottom
cbar = plt.colorbar(sc, ax=ax, orientation='horizontal')
# [left, bottom, width, height]
cbar.ax.set_position([0.2, 0.02, 0.6, 0.3])
if axis_labels:
plt.text(0.12, 0.33, axis_labels[0], transform=plt.gcf().transFigure)
plt.text(0.84, 0.33, axis_labels[1], transform=plt.gcf().transFigure)
plt.text(0.76, 0.87, axis_labels[2], transform=plt.gcf().transFigure)
elif color == "rgb":
rgb = ipf_color(pts, v0 = vs[0], v1 = vs[1], v2=vs[2])
ax.scatter(cpoints[:,0], cpoints[:,1], c=rgb, s = 10.0)
rgb = ipf_color(pts, v0=vs[0], v1=vs[1], v2=vs[2])
ax.scatter(cpoints[:, 0], cpoints[:, 1], c=rgb, s=10.0)
else:
ax.scatter(cpoints[:,0], cpoints[:,1], c='k', s = 10.0)
ax.scatter(cpoints[:, 0], cpoints[:, 1], c='k', s=10.0)
ax.axis('off')
if axis_labels:
plt.text(0.1,0.11,axis_labels[0], transform = plt.gcf().transFigure)
plt.text(0.86,0.11,axis_labels[1], transform = plt.gcf().transFigure)
plt.text(0.74,0.88,axis_labels[2], transform = plt.gcf().transFigure)

for i,j in ((0,1),(1,2),(2,0)):
if not (hasattr(color, '__len__') and axis_labels):
if axis_labels:
plt.text(0.1, 0.11, axis_labels[0],
transform=plt.gcf().transFigure)
plt.text(0.86, 0.11, axis_labels[1],
transform=plt.gcf().transFigure)
plt.text(0.74, 0.88, axis_labels[2],
transform=plt.gcf().transFigure)
for i, j in ((0, 1), (1, 2), (2, 0)):
v1 = vs[i]
v2 = vs[j]
fs = np.linspace(0,1,nline)
fs = np.linspace(0, 1, nline)
pts = np.array([pop((f*v1+(1-f)*v2)/la.norm(f*v1+(1-f)*v2)) for f in fs])
plt.plot(pts[:,0], pts[:,1], color = 'k')
plt.plot(pts[:, 0], pts[:, 1], color='k')
else:
polar = np.array([cart2pol(v) for v in cpoints])
ax = plt.subplot(111, projection='polar')
ax.scatter(polar[:,0], polar[:,1], c='k', s=10.0)
plt.ylim([0,lim])
ax.scatter(polar[:, 0], polar[:, 1], c='k', s=10.0)
plt.ylim([0, lim])
ax.grid(False)
ax.get_xaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.xaxis.set_minor_locator(plt.NullLocator())

Expand Down

0 comments on commit a9dd88b

Please sign in to comment.