Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Fix logical_right_shift op bug #1105

Merged
merged 10 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cinn/hlir/op/contrib/logical_right_shift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/op/op_util.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/hlir/pe/schedule.h"
Expand Down Expand Up @@ -113,7 +114,7 @@ std::shared_ptr<OpStrategy> StrategyForLogicalRightShift(const framework::NodeAt

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(logical_right_shift_compute,
framework::GetInjectiveScheduleFunc(output_shapes, target),
GetInjectiveScheduleFunc(output_shapes, target),
"strategy.logical_right_shift.x86",
1);
return strategy;
Expand Down
2 changes: 1 addition & 1 deletion cinn/runtime/cpu/host_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ inline int FN_INT32(clz)(int x) { return __builtin_clz(x); }

inline int FN_INT32(popc)(int x) { return __builtin_popcount(x); }

inline int FN_INT32(logical_right_shift)(int x, int y) { return (x >> y) & ~(((0x1 << 31) >> y) << 1); }
inline int FN_INT32(logical_right_shift)(int x, int y) { return ((unsigned int)x >> y); }

#undef FN_INT32

Expand Down
2 changes: 1 addition & 1 deletion cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ __device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; }
__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; }
__device__ inline int FN_INT32(clz)(int a) { return __clz(a); }
__device__ inline int FN_INT32(popc)(int a) { return __popc(a); }
__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (a >> b) & ~(((0x1 << 31) >> b) << 1); }
__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return ((unsigned int)a >> b); }

// *************************************************************** //

Expand Down