-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[in progress] Expose API for parameter server #1039
Changes from all commits
32b28c6
8ff42cd
647f9a0
cee9944
8015171
eaf19f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from py_paddle import swig_paddle as api | ||
|
||
#import pudb;pudb.set_trace() | ||
|
||
|
||
def main(): | ||
api.initPaddle("--nics=lo0", "--port=7164", "--ports_num=1", | ||
"--num_gradient_servers=1", "--comment=paddle_pserver") | ||
pserver = api.ParameterServer.createParameterServer() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我没有太看明白这里的API设计。我以为这里是想通过调用Python API启动一个parameter server process? 如果是,那么是不是应该把 L21到L26简化为,比如:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 对,这个后面正在修改,目前长这个样子是因为历史遗留问题,initPaddle实际上是去初始化各种gflags,后面的版本已经改成proto配置了,见#1051 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please follow Python code convention and rename module/package name |
||
pserver.init() | ||
pserver.start() | ||
pserver.join() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please refer to source file naming convention and rename this to be |
||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "PaddleAPI.h"

#include <memory>

#include "PaddleAPIPrivate.h"
|
||
// Default constructor: allocates the private implementation object that
// the rest of the ParameterServer methods forward to through `m`.
ParameterServer::ParameterServer()
    : m(new ParameterServerPrivate()) {
}
|
||
// Factory: builds a ParameterServer and attaches a freshly constructed
// paddle::PServerUtil to its private implementation.
//
// Returns a heap-allocated object; the caller takes ownership of the
// returned pointer.
ParameterServer* ParameterServer::createParameterServer() {
  // Hold the half-built object in a unique_ptr so it is destroyed (instead
  // of leaked) if allocating or constructing the PServerUtil throws.
  std::unique_ptr<ParameterServer> pServer(new ParameterServer());
  pServer->m->pServerUtil.reset(new paddle::PServerUtil());
  // Fully assembled: hand ownership over to the caller.
  return pServer.release();
}
|
||
// Destructor: frees the private implementation object allocated by the
// constructor.
ParameterServer::~ParameterServer() {
  delete m;
}
|
||
// Forwards to PServerUtil::init() on the wrapped utility object.
void ParameterServer::init() {
  auto& util = *m->pServerUtil;
  util.init();
}
|
||
// Forwards to PServerUtil::start() on the wrapped utility object.
void ParameterServer::start() {
  auto& util = *m->pServerUtil;
  util.start();
}
|
||
// Forwards to PServerUtil::join() on the wrapped utility object.
void ParameterServer::join() {
  auto& util = *m->pServerUtil;
  util.join();
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please refer to source file naming convention and rename this to be |
||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "PServerUtil.h" | ||
|
||
namespace paddle { | ||
|
||
void PServerUtil::init() { | ||
// round robin to load balance RDMA server ENGINE | ||
std::vector<std::string> devices; | ||
int rdmaCpu = 0; | ||
int onlineCpus = rdma::numCpus(); | ||
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse; | ||
|
||
if (FLAGS_nics.empty()) { | ||
pservers_.resize(numPorts); | ||
for (int i = 0; i < numPorts; ++i) { | ||
if (FLAGS_rdma_tcp == "rdma") { | ||
pservers_[i].reset( | ||
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++)); | ||
rdmaCpu = rdmaCpu % onlineCpus; | ||
} else { | ||
pservers_[i].reset(new ParameterServer2(std::string(), FLAGS_port + i)); | ||
} | ||
CHECK(pservers_[i]->init()) << "Fail to initialize parameter server" | ||
<< FLAGS_port + i; | ||
} | ||
} else { | ||
str::split(FLAGS_nics, ',', &devices); | ||
pservers_.resize(devices.size() * numPorts); | ||
for (int i = 0; i < numPorts; ++i) { | ||
for (size_t j = 0; j < devices.size(); ++j) { | ||
if (FLAGS_rdma_tcp == "rdma") { | ||
pservers_[i * devices.size() + j].reset(new ParameterServer2( | ||
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++)); | ||
rdmaCpu = rdmaCpu % onlineCpus; | ||
} else { | ||
pservers_[i * devices.size() + j].reset( | ||
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i)); | ||
} | ||
CHECK(pservers_[i * devices.size() + j]->init()) | ||
<< "Fail to initialize parameter server" << devices[j] | ||
<< FLAGS_port + i; | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Starts every server held in pservers_, in storage order, logging the
// total count first and then each server's index just before it is started.
void PServerUtil::start() {
  const auto total = pservers_.size();
  LOG(INFO) << "pserver sizes : " << total;
  for (size_t idx = 0; idx < total; ++idx) {
    LOG(INFO) << "pserver started : " << idx;
    pservers_[idx]->start();
  }
}
|
||
// Logs the number of servers, then calls join() on each server in
// pservers_ in storage order.
void PServerUtil::join() {
  LOG(INFO) << "pserver sizes : " << pservers_.size();
  for (auto it = pservers_.begin(); it != pservers_.end(); ++it) {
    (*it)->join();
  }
}
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please refer to source file naming convention and rename this to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,代码规范问题,会在分开的两个pr中进行,这个pr先关掉。 |
||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "ParameterServer2.h" | ||
#include "RDMANetwork.h" | ||
#include "paddle/utils/StringUtil.h" | ||
|
||
namespace paddle { | ||
|
||
class PServerUtil { | ||
public: | ||
void init(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. init => 构造函数. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里如果方便的话,可以考虑把GFLAGS提取出来,变成函数的参数。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,这里是这么想的,变成参数 |
||
void start(); | ||
void join(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DISABLE_COPY(PServerUtil); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 析构的时候调用join. |
||
private: | ||
std::vector<std::shared_ptr<ParameterServer2>> pservers_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看起来,std::vector<std::unique_ptr< ParameterServer2 >> pservers_; |
||
}; | ||
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,63 +16,16 @@ limitations under the License. */ | |
#include "paddle/utils/StringUtil.h" | ||
#include "paddle/utils/Util.h" | ||
|
||
#include "ParameterServer2.h" | ||
#include "RDMANetwork.h" | ||
#include "paddle/utils/Flags.h" | ||
#include "PServerUtil.h" | ||
|
||
using namespace paddle; // NOLINT | ||
|
||
int main(int argc, char** argv) { | ||
initMain(argc, argv); | ||
|
||
std::vector<std::string> devices; | ||
std::vector<std::shared_ptr<ParameterServer2>> pservers; | ||
|
||
// round robin to loadbalance RDMA server ENGINE | ||
int rdmaCpu = 0; | ||
int onlineCpus = rdma::numCpus(); | ||
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse; | ||
if (FLAGS_nics.empty()) { | ||
pservers.resize(numPorts); | ||
for (int i = 0; i < numPorts; ++i) { | ||
if (FLAGS_rdma_tcp == "rdma") { | ||
pservers[i].reset( | ||
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++)); | ||
rdmaCpu = rdmaCpu % onlineCpus; | ||
} else { | ||
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i)); | ||
} | ||
CHECK(pservers[i]->init()) << "Fail to initialize parameter server" | ||
<< FLAGS_port + i; | ||
LOG(INFO) << "pserver started : " << FLAGS_port + i; | ||
pservers[i]->start(); | ||
} | ||
} else { | ||
str::split(FLAGS_nics, ',', &devices); | ||
pservers.resize(devices.size() * numPorts); | ||
for (int i = 0; i < numPorts; ++i) { | ||
for (size_t j = 0; j < devices.size(); ++j) { | ||
if (FLAGS_rdma_tcp == "rdma") { | ||
pservers[i * devices.size() + j].reset(new ParameterServer2( | ||
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++)); | ||
rdmaCpu = rdmaCpu % onlineCpus; | ||
} else { | ||
pservers[i * devices.size() + j].reset( | ||
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i)); | ||
} | ||
CHECK(pservers[i * devices.size() + j]->init()) | ||
<< "Fail to initialize parameter server" << devices[j] | ||
<< FLAGS_port + i; | ||
LOG(INFO) << "pserver started : " << devices[j] << ":" | ||
<< FLAGS_port + i; | ||
pservers[i * devices.size() + j]->start(); | ||
} | ||
} | ||
} | ||
|
||
for (auto& pserver : pservers) { | ||
pserver->join(); | ||
} | ||
PServerUtil* pserverUtil = new PServerUtil(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. std::unique_ptr pservers(new PServerUtil()); 把init实现在构造函数里,把join放到析构函数里的。 pservers->start(); |
||
pserverUtil->init(); | ||
pserverUtil->start(); | ||
pserverUtil->join(); | ||
|
||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the unused line rather than commenting it out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok,in progress