Skip to content

Commit

Permalink
Merge branch 'main' into yifeng/fft_c2c
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng authored Jul 4, 2024
2 parents e0c678e + 3316f9f commit ee57739
Show file tree
Hide file tree
Showing 53 changed files with 6,744 additions and 625 deletions.
13 changes: 5 additions & 8 deletions .github/actions/inductor-xpu-e2e-test/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,11 @@ runs:
contains "accuracy,performance" $scenario
$contains_status
if [ "${MODEL_ONLY_NAME}" == "" ];then
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 0 static 8 0 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 1 static 8 1 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 2 static 8 2 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 3 static 8 3 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 4 static 8 4 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 5 static 8 5 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 6 static 8 6 &
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 7 static 8 7 &
xpu_list=($(xpu-smi discovery |grep 'DRM Device: /dev/' |sed 's/.*card//;s/[^0-9].*//' |awk '{print $1 - 1":"NR - 1}'))
for xpu_id in ${xpu_list[*]}
do
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu ${xpu_id/:*} static ${#xpu_list[*]} ${xpu_id/*:} &
done
else
bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 0 static 1 0 ${MODEL_ONLY_NAME} &
fi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mobilevit_s,pass,pass,pass,pass,pass
nfnet_l0,pass,pass,pass,pass,pass
pit_b_224,pass,pass,pass,pass,pass
pnasnet5large,pass,pass,pass,pass,pass
poolformer_m36,pass,pass,fail_accuracy,pass,pass
poolformer_m36,pass,pass,pass,pass,pass
regnety_002,pass,pass,pass,pass,pass
repvgg_a2,pass,pass,pass,pass,pass
res2net101_26w_4s,pass,pass,pass,pass,pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ mobilenetv3_large_100,pass,pass,fail_accuracy,pass,pass
mobilevit_s,pass,pass,fail_accuracy,pass,pass
nfnet_l0,pass,pass,pass,pass,pass
pit_b_224,pass,pass,pass,pass,pass
pnasnet5large,pass,pass,fail_accuracy,pass,fail_accuracy
poolformer_m36,pass,pass,fail_accuracy,pass,pass
pnasnet5large,pass,pass,pass,pass,fail_accuracy
poolformer_m36,pass,pass,pass,pass,pass
regnety_002,pass,pass,fail_accuracy,pass,pass
repvgg_a2,pass,pass,fail_accuracy,pass,pass
res2net101_26w_4s,pass,pass,fail_accuracy,pass,pass
Expand Down
14 changes: 7 additions & 7 deletions .github/ci_expected_accuracy/inductor_torchbench_inference.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Background_Matting,pass_due_to_skip,pass_due_to_skip,eager_fail_to_run,pass_due_
DALLE2_pytorch,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
LearningToPaint,pass,pass,pass,pass,pass
Super_SloMo,pass,pass,pass,pass,pass
alexnet,eager_two_runs_differ,pass,eager_two_runs_differ,pass,eager_two_runs_differ
alexnet,eager_two_runs_differ,pass,pass,pass,eager_two_runs_differ
basic_gnn_edgecnn,pass,pass,pass,pass,pass
basic_gnn_gcn,pass,pass,pass,pass,pass
basic_gnn_gin,pass,pass,pass,pass,pass
Expand All @@ -20,11 +20,11 @@ detectron2_fasterrcnn_r_101_fpn,pass,eager_fail_to_run,fail_accuracy,eager_fail_
detectron2_fasterrcnn_r_50_c4,pass,eager_fail_to_run,fail_accuracy,eager_fail_to_run,fail_accuracy
detectron2_fasterrcnn_r_50_dc5,pass,eager_fail_to_run,fail_accuracy,eager_fail_to_run,fail_accuracy
detectron2_fasterrcnn_r_50_fpn,pass,eager_fail_to_run,fail_accuracy,eager_fail_to_run,fail_accuracy
detectron2_fcos_r_50_fpn,pass,fail_accuracy,fail_accuracy,pass,fail_accuracy
detectron2_fcos_r_50_fpn,pass,pass,pass,pass,pass
detectron2_maskrcnn,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
detectron2_maskrcnn_r_101_c4,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
detectron2_maskrcnn_r_101_fpn,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
detectron2_maskrcnn_r_50_c4,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
detectron2_maskrcnn_r_50_c4,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
detectron2_maskrcnn_r_50_fpn,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
dlrm,pass,pass,pass,pass,pass
doctr_det_predictor,pass,pass,pass,eager_fail_to_run,pass
Expand Down Expand Up @@ -61,7 +61,7 @@ mnasnet1_0,pass,pass,pass,pass,pass
mobilenet_v2,pass,pass,pass,pass,pass
mobilenet_v2_quantized_qat,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load
mobilenet_v3_large,pass,pass,pass,pass,pass
moco,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
moco,model_fail_to_load,model_fail_to_load,model_fail_to_load,eager_fail_to_run,model_fail_to_load
moondream,pass,pass,pass,pass,pass
nanogpt,pass,pass,pass,pass,pass
nvidia_deeprecommender,pass,pass,pass,pass,pass
Expand Down Expand Up @@ -89,7 +89,7 @@ speech_transformer,pass,pass,pass,pass,pass
squeezenet1_1,pass,fail_accuracy,fail_accuracy,pass,pass
stable_diffusion_text_encoder,pass,pass,pass,pass,pass
stable_diffusion_unet,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
tacotron2,pass,pass,pass,model_fail_to_load,model_fail_to_load
tacotron2,pass,pass,pass,model_fail_to_load,fail_to_run
timm_efficientdet,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load
timm_efficientnet,pass,pass,pass,pass,pass
timm_nfnet,pass,pass,pass,pass,pass
Expand All @@ -98,8 +98,8 @@ timm_resnest,pass,pass,pass,pass,pass
timm_vision_transformer,pass,pass,pass,pass,pass
timm_vision_transformer_large,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
timm_vovnet,pass,pass,pass,pass,pass
torch_multimodal_clip,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
torch_multimodal_clip,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
tts_angular,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
vgg16,eager_two_runs_differ,pass,eager_two_runs_differ,pass,pass
vgg16,eager_two_runs_differ,pass,pass,pass,pass
vision_maskrcnn,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
yolov3,pass,pass,pass,pass,pass
10 changes: 5 additions & 5 deletions .github/ci_expected_accuracy/inductor_torchbench_training.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name,float32,bfloat16,float16,amp_bf16,amp_fp16
torchrec_dlrm,fail_to_run,eager_fail_to_run,eager_fail_to_run,fail_to_run,fail_to_run
torchrec_dlrm,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,fail_to_run
BERT_pytorch,pass,pass,pass,pass,pass
Background_Matting,pass_due_to_skip,pass_due_to_skip,eager_fail_to_run,pass_due_to_skip,eager_fail_to_run
DALLE2_pytorch,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
Expand Down Expand Up @@ -53,15 +53,15 @@ hf_distil_whisper,model_fail_to_load,model_fail_to_load,model_fail_to_load,model
lennard_jones,pass,pass,pass,pass,pass
llama,pass,pass,pass,pass,pass
llama_v2_7b_16h,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
llava,eager_fail_to_run,eager_2nd_run_fail,eager_2nd_run_fail,eager_fail_to_run,eager_fail_to_run
llava,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
maml,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
maml_omniglot,pass,pass,pass,pass,pass
microbench_unbacked_tolist_sum,pass,pass,pass,pass,pass
mnasnet1_0,pass,pass,pass,pass,pass
mobilenet_v2,pass,pass,pass,pass,pass
mobilenet_v2_quantized_qat,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
mobilenet_v3_large,pass,pass,pass,pass,pass
moco,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
moco,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,eager_fail_to_run
moondream,pass,pass,pass,pass,pass
nanogpt,pass,pass,pass,pass,pass
nvidia_deeprecommender,pass,pass,pass,pass,pass
Expand Down Expand Up @@ -91,14 +91,14 @@ stable_diffusion_text_encoder,pass,pass,pass,pass,pass
stable_diffusion_unet,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
tacotron2,fail_to_run,fail_to_run,fail_to_run,fail_to_run,fail_to_run
timm_efficientdet,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load
timm_efficientnet,pass,pass,pass,fail_accuracy,pass
timm_efficientnet,pass,pass,pass,pass,pass
timm_nfnet,pass,pass,pass,pass,pass
timm_regnet,pass,pass,pass,pass,pass
timm_resnest,pass,pass,pass,pass,pass
timm_vision_transformer,pass,pass,pass,pass,pass
timm_vision_transformer_large,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
timm_vovnet,pass,pass,pass,pass,pass
torch_multimodal_clip,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
torch_multimodal_clip,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
tts_angular,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
vgg16,eager_two_runs_differ,eager_two_runs_differ,eager_two_runs_differ,eager_two_runs_differ,eager_two_runs_differ
vision_maskrcnn,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
Expand Down
2 changes: 2 additions & 0 deletions .github/scripts/apply_torch_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"https://github.com/pytorch/pytorch/pull/127277",
# [Inductor][Intel GPU] Support reduction split.
"https://github.com/pytorch/pytorch/pull/129120",
# Modify the tolerance level in TIMM benchmark
"https://github.com/pytorch/pytorch/pull/129735",
]
)
parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[])
Expand Down
2 changes: 1 addition & 1 deletion .github/scripts/inductor_xpu_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ fi
ulimit -n 1048576
ZE_AFFINITY_MASK=${CARD} \
python benchmarks/dynamo/${SUITE}.py --${SCENARIO} --${Real_DT} -d ${DEVICE} -n10 --no-skip --dashboard \
${DT_extra} ${Mode_extra} ${Shape_extra} ${partition_flags} ${Model_only_extra} --backend=inductor --timeout=7200 \
${DT_extra} ${Mode_extra} ${Shape_extra} ${partition_flags} ${Model_only_extra} --backend=inductor --timeout=10800 \
--output=${LOG_DIR}/${LOG_NAME}.csv 2>&1 | tee ${LOG_DIR}/${LOG_NAME}_card${CARD}.log
2 changes: 1 addition & 1 deletion .github/workflows/inductor_xpu_e2e_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:
cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files
failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true)
if [ ${failed_case} -ne 0 ];then
grep -E "Failed: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
exit 1
fi
- name: Upload Inductor XPU E2E Data
Expand Down
28 changes: 14 additions & 14 deletions .github/workflows/inductor_xpu_e2e_nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ jobs:
cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files
failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true)
if [ ${failed_case} -ne 0 ];then
grep -E "Failed: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
exit 1
fi
- name: Upload Inductor XPU E2E Data
Expand All @@ -260,19 +260,19 @@ jobs:
# Test env
build_url="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
repo="${{ github.repository }}"
TORCH_BRANCH_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_BRANCH_ID }}
TORCH_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_COMMIT_ID }}
DRIVER_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.DRIVER_VERSION }}
BUNDLE_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.BUNDLE_VERSION }}
OS_PRETTY_NAME=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.OS_PRETTY_NAME }}
GCC_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.GCC_VERSION }}
TORCHBENCH_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHBENCH_COMMIT_ID }}
TORCHVISION_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHVISION_COMMIT_ID }}
TORCHAUDIO_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHAUDIO_COMMIT_ID }}
# TORCHTEXT_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHTEXT_COMMIT_ID }}
TRANSFORMERS_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRANSFORMERS_VERSION }}
TIMM_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TIMM_COMMIT_ID }}
TRITON_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRITON_COMMIT_ID }}
TORCH_BRANCH_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_BRANCH_ID }}"
TORCH_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_COMMIT_ID }}"
DRIVER_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.DRIVER_VERSION }}"
BUNDLE_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.BUNDLE_VERSION }}"
OS_PRETTY_NAME="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.OS_PRETTY_NAME }}"
GCC_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.GCC_VERSION }}"
TORCHBENCH_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHBENCH_COMMIT_ID }}"
TORCHVISION_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHVISION_COMMIT_ID }}"
TORCHAUDIO_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHAUDIO_COMMIT_ID }}"
# TORCHTEXT_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHTEXT_COMMIT_ID }}"
TRANSFORMERS_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRANSFORMERS_VERSION }}"
TIMM_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TIMM_COMMIT_ID }}"
TRITON_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRITON_COMMIT_ID }}"
# Test status
if [ "${{ needs.Inductor-XPU-E2E-Nightly-Tests.result }}" == "success" ];then
test_status=Success
Expand Down
124 changes: 124 additions & 0 deletions src/ATen/native/xpu/Activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <ATen/native/xpu/sycl/ActivationHardtanhKernels.h>
#include <ATen/native/xpu/sycl/ActivationLeakyReluKernels.h>
#include <ATen/native/xpu/sycl/ActivationSiluKernels.h>
#include <ATen/native/xpu/sycl/ActivationSoftplusKernels.h>
#include <ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h>
#include <ATen/native/xpu/sycl/ActivationThresholdKernel.h>

namespace at {
Expand Down Expand Up @@ -508,4 +510,126 @@ Tensor& XPUNativeFunctions::leaky_relu_backward_out(
return grad_input;
}

// Meta function for the softplus forward op: configures a unary
// TensorIterator reading `self` and writing `out`. The scalar
// hyper-parameters do not influence iterator construction — they are
// consumed by the kernel — so they are accepted here but unused.
TensorIterator softplus_meta(
    const Tensor& self,
    const Scalar& beta,
    const Scalar& threshold,
    Tensor& out) {
  auto it = TensorIterator::unary_op(out, self);
  return it;
}

// Out-of-place softplus forward. `beta`/`threshold` are forwarded to the
// XPU kernel (see softplus_kernel for the actual math).
Tensor XPUNativeFunctions::softplus(
    const Tensor& self,
    const Scalar& beta,
    const Scalar& threshold) {
  Tensor result;
  auto it = softplus_meta(self, beta, threshold, result);
  native::xpu::softplus_kernel(it, beta, threshold);
  // Return the tensor materialized by the iterator (the iterator allocates
  // the output when `result` is undefined).
  return it.output();
}

// In-place-output variant of softplus: writes into caller-provided `out`
// and returns it, per the standard *_out convention.
Tensor& XPUNativeFunctions::softplus_out(
    const Tensor& self,
    const Scalar& beta,
    const Scalar& threshold,
    Tensor& out) {
  auto it = softplus_meta(self, beta, threshold, out);
  native::xpu::softplus_kernel(it, beta, threshold);
  return out;
}

// Meta function for softplus backward: a binary TensorIterator over
// (grad_output, self) producing `grad_input`. The scalars are kernel-side
// parameters and play no role in iterator setup.
TensorIterator softplus_backward_meta(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& beta,
    const Scalar& threshold,
    Tensor& grad_input) {
  auto it = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  return it;
}

// Out-of-place softplus backward: computes the input gradient from
// `grad_output` and the forward input `self`.
Tensor XPUNativeFunctions::softplus_backward(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& beta,
    const Scalar& threshold) {
  Tensor grad_in;
  auto it = softplus_backward_meta(grad_output, self, beta, threshold, grad_in);
  native::xpu::softplus_backward_kernel(it, beta, threshold);
  // The iterator allocates the gradient tensor when `grad_in` is undefined.
  return it.output();
}

// *_out variant of softplus backward: writes the gradient into the
// caller-provided `grad_input` and returns it.
Tensor& XPUNativeFunctions::softplus_backward_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& beta,
    const Scalar& threshold,
    Tensor& grad_input) {
  auto it =
      softplus_backward_meta(grad_output, self, beta, threshold, grad_input);
  native::xpu::softplus_backward_kernel(it, beta, threshold);
  return grad_input;
}

// Validates the softshrink `lambd` argument: it must be non-negative.
// Raises via TORCH_CHECK otherwise.
static inline void softshrink_check(const Scalar& lambd) {
  const double value = lambd.to<double>();
  TORCH_CHECK(
      value >= 0,
      "lambda must be greater or equal to 0, but found to be ",
      value,
      ".");
}

// Meta function for softshrink forward: validates `lambd`, then builds a
// unary TensorIterator over `self` writing `out`. `lambd` itself is used
// only by the kernel.
TensorIterator softshrink_meta(
    const Tensor& self,
    const Scalar& lambd,
    Tensor& out) {
  softshrink_check(lambd);
  auto it = TensorIterator::unary_op(out, self);
  return it;
}

// Out-of-place softshrink forward; `lambd` is validated in the meta
// function and forwarded to the XPU kernel.
Tensor XPUNativeFunctions::softshrink(const Tensor& self, const Scalar& lambd) {
  Tensor result;
  auto it = softshrink_meta(self, lambd, result);
  native::xpu::softshrink_kernel(it, lambd);
  // The iterator materializes the output when `result` is undefined.
  return it.output();
}

// *_out variant of softshrink: writes into caller-provided `out` and
// returns it.
Tensor& XPUNativeFunctions::softshrink_out(
    const Tensor& self,
    const Scalar& lambd,
    Tensor& out) {
  auto it = softshrink_meta(self, lambd, out);
  native::xpu::softshrink_kernel(it, lambd);
  return out;
}

// Meta function for softshrink backward: a binary TensorIterator over
// (grad_output, self) producing `grad_input`. `lambd` is kernel-side only
// and intentionally unused here.
TensorIterator softshrink_backward_meta(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& lambd,
    Tensor& grad_input) {
  auto it = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  return it;
}

// Out-of-place softshrink backward: computes the input gradient from
// `grad_output` and the forward input `self`.
Tensor XPUNativeFunctions::softshrink_backward(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& lambd) {
  Tensor grad_in;
  auto it = softshrink_backward_meta(grad_output, self, lambd, grad_in);
  native::xpu::softshrink_backward_kernel(it, lambd);
  // The iterator allocates the gradient tensor when `grad_in` is undefined.
  return it.output();
}

// *_out variant of softshrink backward: writes the gradient into the
// caller-provided `grad_input` and returns it.
Tensor& XPUNativeFunctions::softshrink_backward_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& lambd,
    Tensor& grad_input) {
  auto it = softshrink_backward_meta(grad_output, self, lambd, grad_input);
  native::xpu::softshrink_backward_kernel(it, lambd);
  return grad_input;
}

} // namespace at
Loading

0 comments on commit ee57739

Please sign in to comment.