Skip to content

Commit

Permalink
Merge pull request #980 from reyoung/feature/add_const_in_gradient_machine_eval
Browse files Browse the repository at this point in the history

Add const to GradientMachine::eval
  • Loading branch information
gangliao authored Dec 22, 2016
2 parents c1b294a + 4d5a0b0 commit db82a0e
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/GradientMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ class GradientMachine {
/**
* Create an evaluator which can be used for eval()
*/
virtual Evaluator* makeEvaluator() = 0;
virtual Evaluator* makeEvaluator() const = 0;

/**
* evaluate using the given evaluator
*/
virtual void eval(Evaluator* evaluator) = 0;
virtual void eval(Evaluator* evaluator) const = 0;

std::vector<ParameterPtr>& getParameters() { return parameters_; }

Expand Down
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/MultiGradientMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,11 @@ void MultiGradientMachine::finish() {
}
}

Evaluator* MultiGradientMachine::makeEvaluator() {
Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator();
}

void MultiGradientMachine::eval(Evaluator* evaluator) {
void MultiGradientMachine::eval(Evaluator* evaluator) const {
for (auto& thread : threads_) {
SetDevice device(thread->getDeviceId());
thread->getGradientMachine()->eval(evaluator);
Expand Down
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/MultiGradientMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ class MultiGradientMachine : public GradientMachine {

virtual void finish();

virtual Evaluator* makeEvaluator();
virtual Evaluator* makeEvaluator() const;

virtual void eval(Evaluator* evaluator);
virtual void eval(Evaluator* evaluator) const;

bool useGpu() const { return useGpu_; }

Expand Down
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/MultiNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class MultiCombinedEvaluator : public Evaluator {
std::vector<std::unique_ptr<Evaluator>> evaluators_;
};

Evaluator* MultiNetwork::makeEvaluator() {
Evaluator* MultiNetwork::makeEvaluator() const {
MultiCombinedEvaluator* multiCombinedEvaluator = new MultiCombinedEvaluator();
for (size_t i = 0; i < subNetworks_.size(); i++) {
std::unique_ptr<Evaluator> evaluator(subNetworks_[i]->makeEvaluator());
Expand All @@ -180,6 +180,6 @@ Evaluator* MultiNetwork::makeEvaluator() {
return multiCombinedEvaluator;
}

void MultiNetwork::eval(Evaluator* evaluator) { evaluator->eval(*this); }
void MultiNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); }

} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/MultiNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class MultiNetwork : public NeuralNetwork {

virtual void onPassEnd();

virtual Evaluator* makeEvaluator();
virtual Evaluator* makeEvaluator() const;

virtual void eval(Evaluator* evaluator);
virtual void eval(Evaluator* evaluator) const;

const std::vector<std::unique_ptr<NeuralNetwork>>& getSubNetworks() const {
return subNetworks_;
Expand Down
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/NeuralNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ class CombinedEvaluator : public Evaluator {
std::vector<std::unique_ptr<Evaluator>> evaluators_;
};

Evaluator* NeuralNetwork::makeEvaluator() {
Evaluator* NeuralNetwork::makeEvaluator() const {
CombinedEvaluator* combinedEvaluator = new CombinedEvaluator();
auto subModelConfig = std::find_if(config_.sub_models().begin(),
config_.sub_models().end(),
Expand Down Expand Up @@ -383,7 +383,7 @@ Evaluator* NeuralNetwork::makeEvaluator() {
return combinedEvaluator;
}

void NeuralNetwork::eval(Evaluator* evaluator) { evaluator->eval(*this); }
void NeuralNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); }

void NeuralNetwork::setOutputGrad(const std::vector<Argument>& args) {
CHECK_GE(outputLayers_.size(), args.size());
Expand Down
4 changes: 2 additions & 2 deletions paddle/gserver/gradientmachines/NeuralNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class NeuralNetwork : public GradientMachine {

virtual void onPassEnd();

virtual Evaluator* makeEvaluator();
virtual Evaluator* makeEvaluator() const;

virtual void eval(Evaluator* evaluator);
virtual void eval(Evaluator* evaluator) const;
virtual void resetState();
virtual void setOutputGrad(const std::vector<Argument>& args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ void RecurrentGradientMachine::forwardBackward(
LOG(FATAL) << "should not use this function";
}

void RecurrentGradientMachine::eval(Evaluator* evaluator) {
void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
// call printers frame by frame
for (int i = 0; i < maxSequenceLength_; ++i) {
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin";
Expand Down
2 changes: 1 addition & 1 deletion paddle/gserver/gradientmachines/RecurrentGradientMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class RecurrentGradientMachine : public NeuralNetwork {
const UpdateCallback& callback);

virtual void resetState() {}
virtual void eval(Evaluator* evaluator);
virtual void eval(Evaluator* evaluator) const;

const std::vector<int>& getParameterIds() { return parameterIds_; }

Expand Down

0 comments on commit db82a0e

Please sign in to comment.