diff --git a/manim/mobject/three_d/three_dimensions.py b/manim/mobject/three_d/three_dimensions.py index 5732ebb98c..b6124f66ab 100644 --- a/manim/mobject/three_d/three_dimensions.py +++ b/manim/mobject/three_d/three_dimensions.py @@ -926,6 +926,10 @@ def __init__( ): self.thickness = thickness self.resolution = (2, resolution) if isinstance(resolution, int) else resolution + + start = np.array(start, dtype=np.float64) + end = np.array(end, dtype=np.float64) + self.set_start_and_end_attrs(start, end, **kwargs) if color is not None: self.set_color(color) @@ -1183,8 +1187,9 @@ def __init__( height=height, **kwargs, ) - self.cone.shift(end) - self.end_point = VectorizedPoint(end) + np_end = np.asarray(end, dtype=np.float64) + self.cone.shift(np_end) + self.end_point = VectorizedPoint(np_end) self.add(self.end_point, self.cone) self.set_color(color) diff --git a/tests/test_graphical_units/test_threed.py b/tests/test_graphical_units/test_threed.py index b6079e5e4c..64c003577b 100644 --- a/tests/test_graphical_units/test_threed.py +++ b/tests/test_graphical_units/test_threed.py @@ -164,7 +164,7 @@ def param_surface(u, v): def test_get_start_and_end_Arrow3d(): - start, end = ORIGIN, np.array([2, 1, 0]) + start, end = ORIGIN, np.array([2, 1, 0], dtype=np.float64) arrow = Arrow3D(start, end) assert np.allclose( arrow.get_start(), start, atol=0.01 @@ -172,3 +172,19 @@ def test_get_start_and_end_Arrow3d(): assert np.allclose( arrow.get_end(), end, atol=0.01 ), "end points of Arrow3D do not match" + + +def test_type_conversion_in_Line3D(): + start, end = [0, 0, 0], [1, 1, 1] + line = Line3D(start, end) + type_table = [type(item) for item in [*line.get_start(), *line.get_end()]] + bool_table = [t == np.float64 for t in type_table] + assert all(bool_table), "Types of start and end points are not np.float64" + + +def test_type_conversion_in_Arrow3D(): + start, end = [0, 0, 0], [1, 1, 1] + line = Arrow3D(start, end) + type_table = [type(item) for item in [*line.get_start(), *line.get_end()]] + bool_table = [t == np.float64 for t in type_table] + assert all(bool_table), "Types of start and end points are not np.float64"