Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed problem with type conversions in Line3D #4080

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
9 changes: 7 additions & 2 deletions manim/mobject/three_d/three_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,10 @@ def __init__(
):
self.thickness = thickness
self.resolution = (2, resolution) if isinstance(resolution, int) else resolution

start = np.array(start, dtype=np.float64)
end = np.array(end, dtype=np.float64)

self.set_start_and_end_attrs(start, end, **kwargs)
if color is not None:
self.set_color(color)
Expand Down Expand Up @@ -1183,8 +1187,9 @@ def __init__(
height=height,
**kwargs,
)
self.cone.shift(end)
self.end_point = VectorizedPoint(end)
np_end = np.asarray(end, dtype=np.float64)
self.cone.shift(np_end)
self.end_point = VectorizedPoint(np_end)
self.add(self.end_point, self.cone)
self.set_color(color)

Expand Down
18 changes: 17 additions & 1 deletion tests/test_graphical_units/test_threed.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,27 @@ def param_surface(u, v):


def test_get_start_and_end_Arrow3d():
start, end = ORIGIN, np.array([2, 1, 0])
start, end = ORIGIN, np.array([2, 1, 0], dtype=np.float64)
arrow = Arrow3D(start, end)
assert np.allclose(
arrow.get_start(), start, atol=0.01
), "start points of Arrow3D do not match"
assert np.allclose(
arrow.get_end(), end, atol=0.01
), "end points of Arrow3D do not match"


def test_type_conversion_in_Line3D():
start, end = [0, 0, 0], [1, 1, 1]
line = Line3D(start, end)
type_table = [type(item) for item in [*line.get_start(), *line.get_end()]]
bool_table = [t == np.float64 for t in type_table]
assert all(bool_table), "Types of start and end points are not np.float64"


def test_type_conversion_in_Arrow3D():
start, end = [0, 0, 0], [1, 1, 1]
line = Arrow3D(start, end)
type_table = [type(item) for item in [*line.get_start(), *line.get_end()]]
bool_table = [t == np.float64 for t in type_table]
assert all(bool_table), "Types of start and end points are not np.float64"
Loading