[dipu]add ascend profiler #476

Merged · 7 commits · Nov 30, 2023
49 changes: 48 additions & 1 deletion dipu/tests/python/unittests/test_profiler.py
@@ -3,8 +3,18 @@
import torch_dipu
import torchvision.models as models
from torch.profiler import profile, ProfilerActivity
from torch_dipu.testing._internal.common_utils import TestCase, run_tests
from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
from tests.python.utils.local_eviron import local_eviron
import torch._dynamo as dynamo
import subprocess

def check_string_in_directory(directory, search_string):
    # Recursively grep `directory`; True iff any file contains `search_string`.
    grep_process = subprocess.Popen(
        ["grep", "-r", search_string, directory],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    output, _ = grep_process.communicate()
    return bool(output)


class TestProfiler(TestCase):
@@ -50,5 +60,42 @@ def test_profiler(self):

        prof.export_chrome_trace("./dipu_resnet18_profiler.json")

    @onlyOn("NPU")
    def test_aot_profiler(self):
        x = torch.randn(3, 4).cuda()
        y = torch.randn(3, 4).cuda()
        path = "./results/aot/"
        with torch_dipu.profiler.NativeProfile(path, True):
            x.add_(y)

        self.assertTrue(check_string_in_directory(path, "test_profiler.py"))
        self.assertTrue(check_string_in_directory(path, "aten::add_"))
        self.assertTrue(check_string_in_directory(path, "Add"))

    @onlyOn("NPU")
    def test_dicp_profiler(self):
        def fn(x):
            y = torch.nn.functional.softmax(x, -1)
            y = y * 5
            y = torch.relu(y)
            return y

        opt_model = torch.compile(fn, backend='ascendgraph')
        input = torch.randn(2, 3).cuda()
        # warmup
        for _ in range(5):
            opt_model(input)
        path = "./results/dicp/"
        with torch_dipu.profiler.NativeProfile(path, True):
            y = opt_model(input)
            z = y + y

        self.assertTrue(check_string_in_directory(path, "test_profiler.py"))
        self.assertTrue(check_string_in_directory(path, "aten::add"))
        self.assertTrue(check_string_in_directory(path, "mulrelu"))
        self.assertTrue(check_string_in_directory(path, "softmax"))


if __name__ == "__main__":
    run_tests()
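
Note: the grep-based helper above assumes a POSIX environment with grep on PATH. A dependency-free equivalent (a sketch, not part of this PR; the function name is hypothetical) could be:

from pathlib import Path

def check_string_in_directory_py(directory, search_string):
    # Pure-Python sketch of the grep-based helper above.
    for path in Path(directory).rglob("*"):
        if not path.is_file():
            continue
        try:
            if search_string in path.read_text(errors="ignore"):
                return True
        except OSError:
            continue
    return False
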
4 changes: 4 additions & 0 deletions dipu/torch_dipu/csrc_dipu/binding/ExportProfiler.cpp
@@ -12,6 +12,7 @@
#include "csrc_dipu/profiler/profiler.h"
#include "csrc_dipu/profiler/profiler_kineto.h"
#include "csrc_dipu/profiler/profiler_python.h"
#include "csrc_dipu/runtime/device/profilerapis.h"
#include "csrc_dipu/runtime/devproxy/deviceproxy.h"

#include "exportapi.h"
@@ -39,6 +40,9 @@ void exportProfiler(PyObject* module) {
return activities;
});
profile::init();

m.def("_enable_profiler_api", &devapis::enableProfiler);
m.def("_disable_profiler_api", &devapis::disableProfiler);
}

} // namespace dipu
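
For reference, the two new bindings map directly onto the devapis functions declared below; a minimal sketch of driving them by hand from Python (NativeProfile, added later in this PR, wraps exactly this call pair):

from torch_dipu import _C

# Hedged sketch: arguments are (dump_path, call_stack, record_shapes, profile_memory).
_C._enable_profiler_api("./results/raw/", True, False, False)
# ... run the workload to be profiled ...
_C._disable_profiler_api()
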
16 changes: 16 additions & 0 deletions dipu/torch_dipu/csrc_dipu/runtime/device/profilerapis.h
@@ -0,0 +1,16 @@
// Copyright (c) 2023, DeepLink.
#pragma once

#include <string>

#include "./basedef.h"

namespace dipu {
namespace devapis {

DIPU_WEAK void enableProfiler(const std::string& dump_path, bool call_stack,
bool record_shapes, bool profile_memory);
DIPU_WEAK void disableProfiler();

} // end namespace devapis
} // end namespace dipu
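
These declarations are presumably weak-linked (DIPU_WEAK, from basedef.h) so the core runtime still links when a vendor ships no profiler, and callers can probe for an implementation at runtime. A minimal sketch, assuming DIPU_WEAK expands to __attribute__((weak)):

#include <string>

namespace dipu {
namespace devapis {
// Weak declaration: resolves to nullptr unless a vendor implementation
// (such as the Ascend one in this PR) is linked in.
__attribute__((weak)) void enableProfiler(const std::string& dump_path,
                                          bool call_stack, bool record_shapes,
                                          bool profile_memory);
}  // namespace devapis
}  // namespace dipu

bool profilerAvailable() {
  // The address of an undefined weak symbol compares equal to nullptr.
  return dipu::devapis::enableProfiler != nullptr;
}
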
2 changes: 1 addition & 1 deletion dipu/torch_dipu/csrc_dipu/vendor/ascend/CMakeLists.txt
@@ -6,7 +6,7 @@ include(FindAscendToolKit)
# This is not fully correct; the find-CMake logic should be enhanced to set full library paths.
set(VENDOR_INCLUDE_DIRS "${ASCEND_TOOLKIT_ROOT}/include" PARENT_SCOPE)
set(VENDOR_LIB_DIRS "${ASCEND_TOOLKIT_ROOT}/lib64" PARENT_SCOPE)
set(DIPU_VENDOR_LIB ascendcl acl_op_compiler hccl PARENT_SCOPE)
set(DIPU_VENDOR_LIB ascendcl acl_op_compiler hccl msprofiler PARENT_SCOPE)


# rewrite vendor header file path if needed
192 changes: 192 additions & 0 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/profilerimpl.cpp
@@ -0,0 +1,192 @@
// Copyright (c) 2023, DeepLink.
#include <acl/acl.h>
#include <acl/acl_op.h>
#include <acl/acl_op_compiler.h>
#include <acl/acl_prof.h>
#include <array>
#include <cstdint>
#include <cstring>

#include <ATen/record_function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/profiler/util.h>

#include <csrc_dipu/runtime/device/profilerapis.h>
#include <csrc_dipu/vendor/vendorapi.h>

extern "C" aclError aclprofSetStampTagName(void* stamp, const char* tagName,
uint16_t len);
extern "C" aclError aclprofSetStampTraceMessage(void* stamp, const char* msg,
uint32_t msgLen);
extern "C" aclError aclprofSetStampCallStack(void* stamp, const char* callStack,
uint32_t len);

namespace dipu {
namespace devapis {

static const uint64_t kNpuEvents = 431;
static const uint64_t kAicoreMetrics = 1;

class AscendProfiler {
public:
AscendProfiler(const AscendProfiler&) = delete;
AscendProfiler& operator=(const AscendProfiler&) = delete;

// AscendProfiler designed as a singleton
static AscendProfiler& instance();

void enableProfiler(const std::string& dump_path, bool call_stack,
bool record_shapes, bool profile_memory);
void disableProfiler();

std::unique_ptr<at::ObserverContext> startRecordEvent(
const at::RecordFunction& fn);
void finishRecordEvent(const at::RecordFunction& fn,
at::ObserverContext* context) const;

private:
AscendProfiler() = default;
static void recordCallStack(void* stamp, int64_t sequence_num,
at::RecordScope scope);

bool enable_ = false;
aclprofConfig* config_ = nullptr;
bool call_stack_ = false;
bool record_shapes_ = false;
bool profile_memory_ = false;
};

AscendProfiler& AscendProfiler::instance() {
static AscendProfiler profiler;
return profiler;
}

void AscendProfiler::enableProfiler(const std::string& dump_path,
bool call_stack, bool record_shapes,
bool profile_memory) {
if (enable_) {
DIPU_LOGW("ascend profiler has already enabled");
return;
}

call_stack_ = call_stack;
record_shapes_ = record_shapes;
profile_memory_ = profile_memory;
int32_t device_index = 0;
DIPU_CALLACLRT(aclrtGetDevice(&device_index));

std::array<uint32_t, 1> device_ids = {static_cast<uint32_t>(device_index)};
aclprofAicoreEvents* events = nullptr;
config_ = aclprofCreateConfig(
device_ids.data(), device_ids.size(),
static_cast<aclprofAicoreMetrics>(kAicoreMetrics), events, kNpuEvents);
TORCH_CHECK(config_ != nullptr,
"aclprofCreateConfig failed, device_index = ", device_index,
", npu_event = ", kNpuEvents, ", aicore_metrics = ", kAicoreMetrics);

DIPU_CALLACLRT(aclrtSynchronizeDevice());
DIPU_CALLACLRT(aclprofInit(dump_path.c_str(), dump_path.size()));
DIPU_CALLACLRT(aclprofStart(config_));

at::addThreadLocalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
return AscendProfiler::instance().startRecordEvent(fn);
},
[](const at::RecordFunction& fn, at::ObserverContext* ctx) {
AscendProfiler::instance().finishRecordEvent(fn, ctx);
}));
enable_ = true;
}

void AscendProfiler::disableProfiler() {
if (!enable_) {
DIPU_LOGW("ascend profiler has already disabled");
return;
}

DIPU_CALLACLRT(aclrtSynchronizeDevice());
at::clearThreadLocalCallbacks();
DIPU_CALLACLRT(aclprofStop(config_));
DIPU_CALLACLRT(aclprofFinalize());
enable_ = false;
}

struct AscendObserverContext : public at::ObserverContext {
AscendObserverContext(void* d, uint32_t n) : data(d), id(n) {}

void* data = nullptr;
uint32_t id = 0;
};

std::unique_ptr<at::ObserverContext> AscendProfiler::startRecordEvent(
const at::RecordFunction& fn) {
if (!enable_) {
DIPU_LOGW("ascend profiler not enabled, ignore record event");
return std::unique_ptr<AscendObserverContext>();
}

void* stamp = aclprofCreateStamp();
TORCH_CHECK(stamp != nullptr, "aclprofCreateStamp fail",
", error msg = ", aclGetRecentErrMsg());
static const std::string tag_name = "torch_op";
DIPU_CALLACLRT(
aclprofSetStampTagName(stamp, tag_name.c_str(), tag_name.size()));
DIPU_CALLACLRT(
aclprofSetStampTraceMessage(stamp, fn.name(), strlen(fn.name())));

if (call_stack_) {
recordCallStack(stamp, fn.seqNr(), fn.scope());
}

uint32_t range_id = 0;
DIPU_CALLACLRT(aclprofRangeStart(stamp, &range_id));
return std::make_unique<AscendObserverContext>(stamp, range_id);
}

void AscendProfiler::recordCallStack(void* stamp, int64_t sequence_num,
at::RecordScope scope) {
std::string seq_nr = "seq=" + std::to_string(sequence_num);
std::vector<std::string> py_stack;
std::string call_stack_data;

if (scope != at::RecordScope::BACKWARD_FUNCTION) {
auto cs =
torch::profiler::impl::prepareCallstack(torch::jit::currentCallstack());
if (cs.empty()) {
cs = torch::profiler::impl::prepareCallstack(
torch::jit::tracer::pythonCallstack());
}
py_stack = torch::profiler::impl::callstackStr(cs);
call_stack_data = torch::profiler::impl::stacksToStr(py_stack, ";");
} else {
call_stack_data = seq_nr;
}

if (!call_stack_data.empty()) {
DIPU_CALLACLRT(aclprofSetStampCallStack(stamp, call_stack_data.c_str(),
call_stack_data.size()));
}
}

void AscendProfiler::finishRecordEvent(const at::RecordFunction& fn,
at::ObserverContext* context) const {
if (!enable_) {
DIPU_LOGW("ascend profiler not enabled, ignore record event");
return;
}

auto* ctx_ptr = static_cast<AscendObserverContext*>(context);
DIPU_CALLACLRT(aclprofRangeStop(ctx_ptr->id));
aclprofDestroyStamp(ctx_ptr->data);
}

void enableProfiler(const std::string& dump_path, bool call_stack,
bool record_shapes, bool profile_memory) {
AscendProfiler::instance().enableProfiler(dump_path, call_stack,
record_shapes, profile_memory);
}

void disableProfiler() { AscendProfiler::instance().disableProfiler(); }

} // end namespace devapis
} // end namespace dipu
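
For orientation, the per-op bracketing above is the standard at::RecordFunction observer pattern; stripped to its skeleton (ACL stamp/range calls elided as comments), it looks like this sketch:

#include <ATen/record_function.h>
#include <cstdint>
#include <memory>

// Sketch: per-op context carrying the id that aclprofRangeStart would return.
struct RangeCtx : public at::ObserverContext {
  uint32_t id = 0;
};

void registerSketchCallbacks() {
  at::addThreadLocalCallback(at::RecordFunctionCallback(
      [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
        auto ctx = std::make_unique<RangeCtx>();
        // ... aclprofCreateStamp / aclprofSetStamp* / aclprofRangeStart here ...
        return ctx;
      },
      [](const at::RecordFunction&, at::ObserverContext* raw) {
        auto* ctx = static_cast<RangeCtx*>(raw);
        // ... aclprofRangeStop(ctx->id) / aclprofDestroyStamp here ...
        (void)ctx;
      }));
}
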
16 changes: 8 additions & 8 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/vendorapi.h
@@ -8,6 +8,8 @@
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>

#include <c10/util/Exception.h>

#include <csrc_dipu/common.h>

namespace dipu {
@@ -20,14 +22,12 @@ namespace dipu {
} \
}

#define DIPU_CALLACLRT(Expr) \
{ \
TRACK_ACL(#Expr); \
::aclError ret = Expr; \
if (ret != ::ACL_SUCCESS) { \
throw std::runtime_error(std::string("ascend device error:") + \
aclGetRecentErrMsg()); \
} \
#define DIPU_CALLACLRT(Expr) \
{ \
TRACK_ACL(#Expr); \
::aclError ret = Expr; \
TORCH_CHECK(ret == ACL_SUCCESS, "ascend device error, expr = ", #Expr, \
", ret = ", ret, ", error msg = ", aclGetRecentErrMsg()); \
}

using deviceStream_t = aclrtStream;
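
A hedged usage sketch of the reworked macro (aclrtSetDevice is a standard ACL runtime call; the wrapper function is hypothetical):

#include <acl/acl.h>
#include <csrc_dipu/vendor/vendorapi.h>

void setDeviceChecked(int32_t index) {
  // On failure this now raises a c10::Error (via TORCH_CHECK) carrying the
  // expression text, the numeric return code, and aclGetRecentErrMsg().
  DIPU_CALLACLRT(aclrtSetDevice(index));
}
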
1 change: 1 addition & 0 deletions dipu/torch_dipu/profiler/__init__.py
@@ -0,0 +1 @@
from .profiler import NativeProfile
29 changes: 27 additions & 2 deletions dipu/torch_dipu/profiler/profiler.py
@@ -1,3 +1,4 @@
import os
import torch
from torch_dipu import _C
from operator import attrgetter
@@ -20,7 +21,7 @@
def dipu_kineto_available():
    return True

class DIPUProfile(torch.autograd.profiler.profile):
class TorchProfile(torch.autograd.profiler.profile):
    def _parse_kineto_results(self, result):
        # result.events() has most of the events - PyTorch op-level and device-level events

@@ -421,4 +422,28 @@ def apply_profiler_patch():
    setattr(torch.autograd, '_supported_activities', _C._supported_activities)
    setattr(torch.autograd, '_add_metadata_json', _C._add_metadata_json)
    setattr(torch.autograd.profiler_util, '_build_table', dipu_build_table)
    torch.autograd.profiler.profile = DIPUProfile
    torch.autograd.profiler.profile = TorchProfile


class NativeProfile(object):
    def __init__(self, profiler_result_path="./", with_stack=False,
                 record_shapes=False, profile_memory=False):
        self.result_path = profiler_result_path
        self.with_stack = with_stack
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.entered = False
        try:
            os.makedirs(self.result_path, exist_ok=True)
        except Exception:
            raise ValueError("the path '%s' is invalid." % self.result_path)

    def __enter__(self):
        if self.entered:
            raise RuntimeError("native profile traces are not reentrant")

        self.entered = True
        _C._enable_profiler_api(self.result_path, self.with_stack,
                                self.record_shapes, self.profile_memory)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _C._disable_profiler_api()
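
A minimal usage sketch of NativeProfile, mirroring test_aot_profiler above (the output directory is arbitrary):

import torch
import torch_dipu

# Dumps Ascend profiling results into ./results/demo/ for offline inspection;
# with_stack=True additionally records Python call stacks.
with torch_dipu.profiler.NativeProfile("./results/demo/", with_stack=True):
    x = torch.randn(3, 4).cuda()
    x.add_(x)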