Skip to content

Commit

Permalink
Ywt/refactor max pool2d (DeepLink-org#396)
Browse files Browse the repository at this point in the history
* refactor maxpool2d and fix the bug
  • Loading branch information
yewentao256 authored Sep 15, 2023
1 parent 9f2679a commit 31b62f6
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 219 deletions.
9 changes: 1 addition & 8 deletions impl/camb/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,9 @@
'max_pool2d': dict(
name=["max_pool2d"],
para=dict(
# camb kernel only support dilation == 1
dilation=[Skip((4, 3)), Skip((2, 3)), Skip(2)],
),
tensor_para=dict(
args=[
{
"ins": ['input'],
"dtype": [Skip(Dtype.float32), Skip(Dtype.float16)],
},
]
),
),

'max_pool2d_return_indices': dict(
Expand Down
16 changes: 1 addition & 15 deletions impl/camb/diopi_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ const char* DiopiDataType::dataTypeStr(diopiDtype_t dtype) {

DiopiTensor::DiopiTensor(const diopiTensorHandle_t& tensor) : tensor_(tensor) {
if (tensor_ != nullptr) {
DIOPI_CHECK_ABORT(this->device() == diopiDevice_t::diopi_device, "%s", "tensor_ is not on camb device.");
diopiSize_t diopiShape;
diopiSize_t diopiStride;
diopiDtype_t diopiDtype;
Expand Down Expand Up @@ -260,22 +261,7 @@ const void* DiopiTensor::data() const {
return p;
}

diopiTensorHandle_t DiopiTensor::tensorHandle() {
if (this->defined()) {
DIOPI_CHECK_ABORT(this->device() == diopiDevice_t::diopi_device, "%s", "tensor_ is not on camb device.");
}
return tensor_;
}

diopiConstTensorHandle_t DiopiTensor::tensorHandle() const {
if (this->defined()) {
DIOPI_CHECK_ABORT(this->device() == diopiDevice_t::diopi_device, "%s", "tensor_ is not on camb device.");
}
return tensor_;
}

// other funcs

DiopiTensor makeTensor(diopiContextHandle_t ctx, const diopiScalar_t* pScalar) {
diopiTensorHandle_t tensor = nullptr;
std::vector<int64_t> shape{1};
Expand Down
4 changes: 2 additions & 2 deletions impl/camb/diopi_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ class DiopiTensor final {
void* data();
const void* data() const;

diopiTensorHandle_t tensorHandle();
diopiConstTensorHandle_t tensorHandle() const;
diopiTensorHandle_t tensorHandle() { return tensor_; }
diopiConstTensorHandle_t tensorHandle() const { return tensor_; }

bool isSame(DiopiTensor t) { return this->tensorHandle() == t.tensorHandle(); }

Expand Down
Loading

0 comments on commit 31b62f6

Please sign in to comment.