diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 8c800c6c52c5..cdc983739ea8 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -11,7 +11,6 @@
 # python less than 3.10 doesn't have NoneType
 NoneType = type(None)
 
-import ctypes
 import logging
 import operator
 import re
@@ -863,7 +862,7 @@ def import_frozen_program(
 
         # Lift buffers.
         for input_name, state_name in sig.inputs_to_buffers.items():
-            if hasattr(prog, 'constants') and state_name in prog.constants:
+            if hasattr(prog, "constants") and state_name in prog.constants:
                 state_value = prog.constants[state_name]
             else:
                 try:
@@ -2076,21 +2075,13 @@ def _make_vtensor_literal_op(
         assert (
             npy_dtype is not None
         ), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}"
-
-        if tensor.is_contiguous():
-            cpu_tensor = tensor.cpu()
-            # If the tensor is contiguous and on CPU we can avoid the costly Tensor -> list -> numpy array below.
-            # Instead we can use ctypes to access the underlying bytes directly.
-            buffer_pointer = ctypes.cast(cpu_tensor.untyped_storage().data_ptr(), ctypes.POINTER(ctypes.c_char * cpu_tensor.untyped_storage().nbytes()))
-            np_tensor = np.frombuffer(bytes(buffer_pointer.contents), dtype=npy_dtype).reshape(cpu_tensor.shape)
-        else:
-            # We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
-            # but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get
-            # a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to
-            # detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
-            # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
-            # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
-            np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
+        # We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
+        # but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get
+        # a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to
+        # detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
+        # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
+        # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
+        np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
         # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
         # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
         # 0d tensors.
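
Note (not part of the patch): the path kept by this diff builds the raw buffer via the Tensor -> list -> numpy indirection described in the comment above, rather than the removed ctypes fast path. Below is a minimal standalone sketch of that conversion; the names demo_tensor and npy_dtype are illustrative stand-ins, not identifiers from fx_importer.py.

import numpy as np
import torch

# Stand-in for a tensor encountered during import; the real code receives it as an argument.
demo_tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)
# In fx_importer.py the numpy dtype is looked up via TORCH_DTYPE_TO_NPY_TYPE.
npy_dtype = np.float32

# torch.Tensor does not implement the Python buffer protocol, and .numpy()
# forces a detach() that fails under FakeTensorMode, so the conversion goes
# through a plain Python list first.
np_tensor = np.array(demo_tensor.tolist()).astype(npy_dtype)

# The resulting ndarray exposes a contiguous raw byte buffer of the kind an
# ElementsAttr (e.g. for torch.vtensor.literal) can be constructed from.
raw_bytes = np_tensor.tobytes()
assert len(raw_bytes) == np_tensor.size * np_tensor.itemsize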