[AOTInductor] remove CUDA dependency for cpp backend (pytorch#110409)
Summary:
Previously, we linked against CUDA libs even for the pure cpp backend.
This caused issues on inference platforms that do not have GPUs.
This diff removes the CUDA dependency for the cpp backend.

Reviewed By: bertmaher, muchulee8, mikekgfb

Differential Revision: D49800712

Pull Request resolved: pytorch#110409
Approved by: https://github.com/bertmaher, https://github.com/desertfire
chenyang78 authored and pytorchmergebot committed Oct 3, 2023
1 parent df3ab70 commit da63c7f
Showing 8 changed files with 157 additions and 65 deletions.
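
For orientation before the per-file diff: the change routes every CUDA-specific type and check in the AOTI runtime through a small device abstraction selected at compile time with USE_CUDA. The following standalone sketch is not part of the commit; the names mirror the new device_utils.h shown further down, and the snippet builds with or without -DUSE_CUDA.

// Illustrative sketch only; mirrors the DeviceStreamType alias that this
// commit adds in torch/csrc/inductor/aoti_runtime/device_utils.h.
#include <cstdio>

#ifdef USE_CUDA
#include <cuda_runtime_api.h>
namespace sketch { using DeviceStreamType = cudaStream_t; }
#else
namespace sketch { using DeviceStreamType = void*; }
#endif

// Model code can take a stream parameter without naming any CUDA type,
// so the same signature compiles on a GPU-less host.
void run_on_stream(sketch::DeviceStreamType stream) {
  std::printf("running with stream=%p\n", static_cast<void*>(stream));
}

int main() {
  // A CPU-only caller simply passes a null stream handle, matching the
  // nullptr handle used by the test launcher later in this diff.
  run_on_stream(nullptr);
  return 0;
}

With the alias in place, generated model code never names cudaStream_t directly, so a CPU-only build no longer needs to link against the CUDA runtime.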
2 changes: 1 addition & 1 deletion test/inductor/test_aot_inductor.py
@@ -69,7 +69,7 @@ def load(cls, model, example_inputs, options=None, constraints=None):
cpp_sources=[launcher],
functions=["run"],
extra_ldflags=[so_path],
with_cuda=True, # TODO: change this to not is_cpu
with_cuda=not is_cpu,
).run

return optimized, exported
31 changes: 15 additions & 16 deletions torch/_inductor/codecache.py
@@ -789,10 +789,6 @@ def get_include_and_linking_paths(
os.environ["CUDA_HOME"] = os.path.dirname(build_paths.cuda())
from torch.utils import cpp_extension

if aot_mode:
# Hack. The AOT inductor libs reference CUDA, so let's just include it for now.
cuda = True

macros = ""
if sys.platform == "linux" and (
include_pytorch
@@ -819,17 +815,18 @@
libs += ["omp"]
if aot_mode:
ipaths += [os.path.dirname(cpp_prefix_path())]
# This is a special treatment for Meta internal cuda-12 where all libs
# are in lib/cuda-12 and lib/cuda-12/stubs
for i, path in enumerate(lpaths):
if path.startswith(os.environ["CUDA_HOME"]) and not os.path.exists(
f"{path}/libcudart_static.a"
):
for root, dirs, files in os.walk(path):
if "libcudart_static.a" in files:
lpaths[i] = os.path.join(path, root)
lpaths.append(os.path.join(lpaths[i], "stubs"))
break
if cuda:
# This is a special treatment for Meta internal cuda-12 where all libs
# are in lib/cuda-12 and lib/cuda-12/stubs
for i, path in enumerate(lpaths):
if path.startswith(
os.environ["CUDA_HOME"]
) and not os.path.exists(f"{path}/libcudart_static.a"):
for root, dirs, files in os.walk(path):
if "libcudart_static.a" in files:
lpaths[i] = os.path.join(path, root)
lpaths.append(os.path.join(lpaths[i], "stubs"))
break
macros = vec_isa.build_macro()
if macros:
if config.is_fbcode() and vec_isa != invalid_vec_isa:
@@ -861,6 +858,8 @@ def get_include_and_linking_paths(
# For those cases, include the lpath and libs command as we do for pytorch above.
# This approach allows us to only pay for what we use.
ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")]
if aot_mode:
ipaths += [os.path.dirname(cpp_prefix_path())]
lpaths = []
if sys.platform == "darwin":
# only Apple builtin compilers (Apple Clang++) require openmp
@@ -921,7 +920,7 @@ def get_include_and_linking_paths(
ipaths.append("include")

static_link_libs = []
if aot_mode and config.is_fbcode():
if aot_mode and cuda and config.is_fbcode():
# For Meta internal cuda-12, it is recommended to static link cudart
static_link_libs = ["-Wl,-Bstatic", "-lcudart_static", "-Wl,-Bdynamic"]

8 changes: 6 additions & 2 deletions torch/_inductor/codegen/aoti_runtime/interface.cpp
@@ -77,7 +77,7 @@ AOTIRuntimeError AOTInductorModelContainerRun(
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");

auto stream = reinterpret_cast<cudaStream_t>(stream_handle);
auto stream = reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
container->run(
input_handles,
@@ -208,7 +211,7 @@ AOTIRuntimeError AOTInductorModelRun(
AtenTensorHandle* output_handles) {
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
model->run_impl(input_handles, output_handles, (cudaStream_t)nullptr, nullptr);
model->run_impl(
input_handles,
output_handles,
(torch::aot_inductor::DeviceStreamType)nullptr,
nullptr);
})
}

2 changes: 1 addition & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -1131,7 +1131,7 @@ def write_wrapper_decl(self):
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
cudaStream_t stream,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
"""
6 changes: 6 additions & 0 deletions torch/_inductor/utils.py
@@ -1176,7 +1176,9 @@ class Placeholder(enum.Enum):

# A utility function for easier AOTInductor testing
aot_inductor_launcher = """
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif // USE_CUDA
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
@@ -1219,10 +1221,14 @@ class RAIIModelContainer {
&num_outputs));
std::vector<AtenTensorHandle> output_handles(num_outputs);
#ifdef USE_CUDA
const auto& cuda_stream = c10::cuda::getCurrentCUDAStream();
const auto stream_id = cuda_stream.stream();
AOTInductorStreamHandle stream_handle =
reinterpret_cast<AOTInductorStreamHandle>(stream_id);
#else // !USE_CUDA
AOTInductorStreamHandle stream_handle = nullptr;
#endif
AOTIProxyExecutorHandle proxy_executor_handle = nullptr;
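
The launcher change above reduces to choosing the stream handle at compile time. Below is a minimal sketch of that selection, assuming it is compiled inside a PyTorch extension build; void* stands in for the opaque AOTInductorStreamHandle, and the CUDA branch uses the same c10::cuda::getCurrentCUDAStream() call as the diff.

#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif

// Returns the handle the launcher would forward to the AOTI runtime.
void* current_stream_handle() {
#ifdef USE_CUDA
  // Hand the current CUDA stream to the AOTI runtime as an opaque pointer.
  return reinterpret_cast<void*>(c10::cuda::getCurrentCUDAStream().stream());
#else
  // CPU-only build: there is no device stream, so the handle is null.
  return nullptr;
#endif
}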
torch/csrc/inductor/aoti_runtime/{cuda_utils.h → device_utils.h}
@@ -4,10 +4,17 @@
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.

#ifdef USE_CUDA

// FIXME: Currently, CPU and CUDA backend are mutually exclusive.
// This is a temporary workaround. We need a better way to support
// multi devices.

#include <cuda.h>
#include <cuda_runtime_api.h>

#define AOTI_RUNTIME_CUDA_CHECK(EXPR) \
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \
do { \
const cudaError_t code = EXPR; \
const char* msg = cudaGetErrorString(code); \
@@ -16,3 +23,29 @@
std::string("CUDA error: ") + std::string(msg)); \
} \
} while (0)

namespace torch {
namespace aot_inductor {

using DeviceStreamType = cudaStream_t;

} // namespace aot_inductor
} // namespace torch

#else // !USE_CUDA

#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \
bool ok = EXPR; \
if (!ok) { \
throw std::runtime_error("CPU runtime error"); \
}

namespace torch {
namespace aot_inductor {

using DeviceStreamType = void*;

} // namespace aot_inductor
} // namespace torch

#endif // USE_CUDA
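
A call-site sketch of the dual contract the renamed macro now has (assumes the torch include directories are on the include path): under USE_CUDA the argument is a CUDA runtime call returning cudaError_t, while in the CPU build it is a plain boolean success expression.

#include <stdexcept>
#include <string>
// After this commit the header below provides AOTI_RUNTIME_DEVICE_CHECK and
// DeviceStreamType for both backends.
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>

void check_example(bool succeeded) {
#ifdef USE_CUDA
  (void)succeeded;  // unused on the CUDA path of this sketch
  // EXPR is a CUDA runtime call that returns cudaError_t.
  AOTI_RUNTIME_DEVICE_CHECK(cudaSetDevice(0));
#else
  // EXPR is any expression convertible to bool.
  AOTI_RUNTIME_DEVICE_CHECK(succeeded);
#endif
}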
42 changes: 33 additions & 9 deletions torch/csrc/inductor/aoti_runtime/model.h
@@ -11,7 +11,7 @@
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/cuda_utils.h>
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
Expand Down Expand Up @@ -112,18 +112,23 @@ class AOTInductorModelBase {
: inputs_info_(num_inputs),
outputs_info_(num_outputs),
constants_info_(num_constants),
cubin_dir_(cubin_dir) {
AOTI_RUNTIME_CUDA_CHECK(cudaGetDevice(&device_idx_));
cubin_dir_(cubin_dir),
device_idx_(-1) {
#ifdef USE_CUDA
AOTI_RUNTIME_DEVICE_CHECK(cudaGetDevice(&device_idx_));
#endif // USE_CUDA
}

~AOTInductorModelBase() {
#ifdef USE_CUDA
if (run_finished_) {
auto code = cudaEventDestroy(*run_finished_);
if (code != cudaSuccess) {
std::cerr << "Failed to destroy CUDA event in AOTInductor model: "
<< cudaGetErrorString(code) << std::endl;
}
}
#endif // USE_CUDA
}

AOTInductorModelBase(AOTInductorModelBase&&) = delete;
Expand All @@ -139,17 +144,24 @@ class AOTInductorModelBase {
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
cudaStream_t stream,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor) {
#ifdef USE_CUDA
if (!run_finished_) {
cudaEvent_t run_finished;
AOTI_RUNTIME_CUDA_CHECK(cudaEventCreate(&run_finished));
AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
run_finished_.emplace(run_finished);
}

auto* model = static_cast<Model*>(this);
model->run_impl(input_handles, output_handles, stream, proxy_executor);
AOTI_RUNTIME_CUDA_CHECK(cudaEventRecord(*run_finished_, stream));
AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
#else // !USE_CUDA
run_finished_ = false;
auto* model = static_cast<Model*>(this);
model->run_impl(input_handles, output_handles, stream, proxy_executor);
run_finished_ = true;
#endif // USE_CUDA
}

size_t num_inputs() const {
@@ -230,6 +242,7 @@ class AOTInductorModelBase {

/// Returns true if the model is complete.
bool is_finished() {
#ifdef USE_CUDA
if (!run_finished_) {
throw std::runtime_error{"Model CUDA event was not initialized"};
}
@@ -244,15 +257,20 @@
throw std::runtime_error(
std::string("The model did not finish successfully. Error: ") +
cudaGetErrorString(cudaGetLastError()));
#else // !USE_CUDA
return run_finished_;
#endif // USE_CUDA
}

/// Synchronizes completion event.
void wait_for_completion() {
#ifdef USE_CUDA
if (!run_finished_) {
throw std::runtime_error{"Model CUDA event was not initialized"};
throw std::runtime_error{"Model event was not initialized"};
}

AOTI_RUNTIME_CUDA_CHECK(cudaEventSynchronize(*run_finished_));
AOTI_RUNTIME_DEVICE_CHECK(cudaEventSynchronize(*run_finished_));
#endif // USE_CUDA
}

protected:
@@ -374,7 +392,11 @@ class AOTInductorModelBase {

// Record if the model finishes an inference run so that its owning
// AOTModelContainer can re-use this instance.
#ifdef USE_CUDA
std::optional<cudaEvent_t> run_finished_;
#else // !USE_CUDA
bool run_finished_;
#endif

// Generated model uses this device index to create CUDA guards.
int device_idx_;
@@ -423,7 +445,7 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
cudaStream_t stream,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor);

static std::unique_ptr<AOTInductorModel> Create(
@@ -433,6 +455,7 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
}
};

#ifdef USE_CUDA
class AOTICudaStreamGuard {
public:
AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) {
@@ -449,6 +472,7 @@ class AOTICudaStreamGuard {
private:
std::unique_ptr<void, std::function<void(void*)>> guard_;
};
#endif // USE_CUDA

} // namespace aot_inductor
} // namespace torch
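
Finally, a condensed standalone sketch of the run_finished_ bookkeeping that model.h now compiles per backend. RunTracker is an illustrative name, not part of the runtime; the sketch assumes C++17 and, for the CUDA path, the CUDA runtime headers.

#include <optional>
#include <stdexcept>
#ifdef USE_CUDA
#include <cuda_runtime_api.h>
#endif

class RunTracker {
 public:
  bool is_finished() {
#ifdef USE_CUDA
    if (!run_finished_) {
      throw std::runtime_error("CUDA event was not initialized");
    }
    // cudaEventQuery reports cudaSuccess once the work recorded on the
    // event has completed and cudaErrorNotReady while it is still running.
    return cudaEventQuery(*run_finished_) == cudaSuccess;
#else
    // The CPU path runs synchronously, so a bool toggled around run_impl
    // is enough to answer the same question.
    return run_finished_;
#endif
  }

 private:
#ifdef USE_CUDA
  std::optional<cudaEvent_t> run_finished_;
#else
  bool run_finished_ = false;
#endif
};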