Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move pow2_decay_with_linear_warmup kernel to phi #53741

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"

Expand Down Expand Up @@ -78,12 +76,7 @@ When step_num > total_steps, lr = end_lr
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOp,
ops::Pow2DecayWithLinearWarmupOpMaker);
REGISTER_OP_CPU_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::CPUContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::CPUContext, float>);
125 changes: 0 additions & 125 deletions paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"

REGISTER_OP_CUDA_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::GPUContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::GPUContext, float>);
PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
CPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
25 changes: 25 additions & 0 deletions paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"

PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
GPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
110 changes: 110 additions & 0 deletions paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

template <typename T, typename AttrT>
struct Pow2DecayWithLinearWarmupFunctor {
template <typename U>
using RestrictPtr = U* PADDLE_RESTRICT;

public:
HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
RestrictPtr<int64_t> step,
size_t warmup_steps,
size_t total_steps,
AttrT base_lr,
AttrT end_lr)
: lr_(lr),
step_(step),
warmup_steps_(warmup_steps),
total_steps_(total_steps),
base_lr_(base_lr),
end_lr_(end_lr) {}

HOSTDEVICE void operator()(size_t) const {
size_t step = static_cast<size_t>(*step_) + 1;
*step_ = static_cast<int64_t>(step);
if (step <= warmup_steps_) {
auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
*lr_ = static_cast<T>(new_lr);
} else if (step < total_steps_) {
auto factor = 1 - static_cast<double>(step - warmup_steps_) /
(total_steps_ - warmup_steps_);
auto new_lr =
static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
*lr_ = static_cast<T>(new_lr);
} else {
*lr_ = static_cast<T>(end_lr_);
}
}

private:
RestrictPtr<T> lr_;
RestrictPtr<int64_t> step_;
size_t warmup_steps_;
size_t total_steps_;
AttrT base_lr_;
AttrT end_lr_;
};

template <typename T, typename Context>
void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
const DenseTensor& lr,
const DenseTensor& step,
int64_t warmup_steps,
int64_t total_steps,
float base_lr,
float end_lr,
DenseTensor* lr_out,
DenseTensor* step_out) {
PADDLE_ENFORCE_EQ(&lr,
lr_out,
phi::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_EQ(&step,
step_out,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));

PADDLE_ENFORCE_LE(warmup_steps,
total_steps,
phi::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));

auto* lr_data = lr_out->data<T>();
auto* step_data = step_out->data<int64_t>();
phi::funcs::ForRange<Context> for_range(dev_ctx, 1);
using AttrT = double;
Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
lr_data,
step_data,
static_cast<size_t>(warmup_steps),
static_cast<size_t>(total_steps),
static_cast<AttrT>(base_lr),
static_cast<AttrT>(end_lr));
for_range(functor);
}
} // namespace phi
Loading