Skip to content

Commit

Permalink
Revert "Speed up FxImporter by avoiding copying constants." (#4)
Browse files Browse the repository at this point in the history
This reverts commit 2c5c287.
  • Loading branch information
JamesMBartlett committed Oct 18, 2024
1 parent b45ee5a commit 4dde7c5
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# python less than 3.10 doesn't have NoneType
NoneType = type(None)

import ctypes
import logging
import operator
import re
Expand Down Expand Up @@ -863,7 +862,7 @@ def import_frozen_program(

# Lift buffers.
for input_name, state_name in sig.inputs_to_buffers.items():
if hasattr(prog, 'constants') and state_name in prog.constants:
if hasattr(prog, "constants") and state_name in prog.constants:
state_value = prog.constants[state_name]
else:
try:
Expand Down Expand Up @@ -2076,21 +2075,13 @@ def _make_vtensor_literal_op(
assert (
npy_dtype is not None
), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}"

if tensor.is_contiguous():
cpu_tensor = tensor.cpu()
# If the tensor is contiguous and on CPU we can avoid the costly Tensor -> list -> numpy array below.
# Instead we can use ctypes to access the underlying bytes directly.
buffer_pointer = ctypes.cast(cpu_tensor.untyped_storage().data_ptr(), ctypes.POINTER(ctypes.c_char * cpu_tensor.untyped_storage().nbytes()))
np_tensor = np.frombuffer(bytes(buffer_pointer.contents), dtype=npy_dtype).reshape(cpu_tensor.shape)
else:
# We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
# but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get
# a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
# We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
# but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get
# a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
# One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
# support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
# 0d tensors.
Expand Down

0 comments on commit 4dde7c5

Please sign in to comment.