Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add SyncBatchNorm #387

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Changelog for BytePS
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0.2.5.post2 (2020-11)
* add ability to collect PushPull performance data

0.2.4 (2020-06)
------------------
* Fix compatibility issue with tf2 + standalone keras
Expand Down
5 changes: 3 additions & 2 deletions byteps/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def init(self, lazy=True):
"""A function that inits BytePS."""
atexit.register(self.shutdown)
if lazy:
return self.C_LIB_CTYPES.byteps_lazy_init()
ret = self.C_LIB_CTYPES.byteps_lazy_init()
else:
return self.C_LIB_CTYPES.byteps_init()
ret = self.C_LIB_CTYPES.byteps_init()
return ret

def shutdown(self):
"""A function that shuts BytePS down."""
Expand Down
7 changes: 6 additions & 1 deletion byteps/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ enum QueueType {
COPYH2D,
COORDINATE_BROADCAST,
BROADCAST,
ALLGATHER,
COORDINATE_ALLGATHER,
QUEUE_NUM_AND_NOT_A_REAL_QUEUE_TYPE_AND_MUST_BE_THE_LAST
};

Expand All @@ -115,7 +117,10 @@ const std::vector<std::string> LogStrings = {"COORDINATE_REDUCE",
"DECOMPRESS",
"COPYH2D",
"COORDINATE_BROADCAST",
"BROADCAST"};
"BROADCAST",
"ALLGATHER",
"COORDINATE_ALLGATHER",
};

class Status {
public:
Expand Down
3 changes: 3 additions & 0 deletions byteps/common/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ void BytePSCommSocket::startListenThread() { // only root starts this in
case BCAST_READY:
BytePSGlobal::GetBroadcastTable()->AddReadyCount(message.key);
break;
case ALLGATHER_READY:
BytePSGlobal::GetAllgatherTable()->AddReadyCount(message.key);
break;
case PUSH_READY:
BytePSGlobal::GetPushTable()->AddReadyCount(message.key);
break;
Expand Down
2 changes: 2 additions & 0 deletions byteps/common/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ enum BytePSCommSignal {
REDUCE_READY,
PCIE_REDUCE_READY,
BCAST_READY,
ALLGATHER_READY,
PUSH_READY,
DO_REDUCE,
DO_BROADCAST,
DO_ALLGATHER,
DO_GROUP,
DO_COPYH2D
};
Expand Down
57 changes: 52 additions & 5 deletions byteps/common/core_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ bool RunCoordinateLoopOnce(QueueType this_op) {
comm = BytePSGlobal::GetNccl()->GetSignalComm();
break;
}
case COORDINATE_ALLGATHER: {
sig = ALLGATHER_READY;
comm = BytePSGlobal::GetNccl()->GetSignalComm();
break;
}
case COORDINATE_PUSH: {
sig = PUSH_READY;
comm = BytePSGlobal::GetBasicComm();
Expand Down Expand Up @@ -189,9 +194,20 @@ bool RunCoordinateLoopOnce(QueueType this_op) {

inline void PostNcclCalls(
std::shared_ptr<byteps::common::TensorTableEntry> task, QueueType this_op) {
BPS_CHECK(this_op == REDUCE || this_op == BROADCAST)
<< "Only REDUCE and BROADCAST use NCCL.";
auto tensor = (this_op == REDUCE) ? task->tensor : task->output;
BPS_CHECK(this_op == REDUCE || this_op == BROADCAST || this_op == ALLGATHER)
<< "Only REDUCE, BROADCAST and ALLGATHER use NCCL.";

decltype(task->tensor) tensor;

switch (this_op) {
case REDUCE:
case ALLGATHER: {
tensor = task->tensor;
break;
}
default:
tensor = task->output;
}
BPS_CHECK(tensor);
BPS_CHECK_EQ(0, tensor->size() % tensor->shape().num_elements());

Expand All @@ -203,6 +219,7 @@ inline void PostNcclCalls(
if (task->device == CPU_DEVICE_ID) {
p = (char *)(task->gpu_ptr) + offset;
}
auto out_p = (char *)(task->output->data()) + offset;

auto nccl_dtype = getNcclDataType(tensor->dtype());

Expand All @@ -213,6 +230,7 @@ inline void PostNcclCalls(
auto nccl_size = nccl->GetSize();
auto nccl_rank = nccl->GetRank(key, this_op);

auto num_elem_all = len / unit_len;
auto num_elem_per_gpu = len / nccl_size / unit_len;
auto left_elem = (len / unit_len) - (num_elem_per_gpu * nccl_size);
if (BytePSGlobal::IsUsingReduce()) {
Expand Down Expand Up @@ -251,6 +269,12 @@ inline void PostNcclCalls(
(ncclRedOp_t)ncclSum, (int)nccl_root,
(ncclComm_t)nccl_comm, (cudaStream_t)nccl_stream));
}
} else if (this_op == ALLGATHER) {
BPS_CHECK(task->device != CPU_DEVICE_ID);
NCCLCHECK(ncclAllGather(
(const void *)(p),
(void *)out_p, (size_t)num_elem_all, (ncclDataType_t)nccl_dtype,
(ncclComm_t)nccl_comm, (cudaStream_t)nccl_stream));
} else {
if (num_elem_per_gpu) {
NCCLCHECK(ncclAllGather(
Expand All @@ -275,7 +299,7 @@ bool RunRootNcclLoopOnce() {
BPS_CHECK_EQ(rank, root);

int nccl_size = BytePSGlobal::GetNccl()->GetSize();
QueueType nccl_ops[] = {REDUCE, BROADCAST};
QueueType nccl_ops[] = {REDUCE, BROADCAST, ALLGATHER};

auto nccl_entry = std::make_shared<NcclGroupEntry>();
auto &tasks = nccl_entry->tasks;
Expand All @@ -294,8 +318,22 @@ bool RunRootNcclLoopOnce() {

if (nccl_size > 1) {
// notify non-root devices
BytePSCommSignal sig;
switch (this_op) {
case REDUCE:
sig = DO_REDUCE;
break;
case BROADCAST:
sig = DO_BROADCAST;
break;
case ALLGATHER:
sig = DO_ALLGATHER;
break;
default:
BPS_CHECK(0) << "unsupported operation: " << this_op;
}
struct BytePSCommMsg msg = {
rank, (this_op == REDUCE) ? DO_REDUCE : DO_BROADCAST, task->key};
rank, sig, task->key};
signal_comm->broadcastSignal(&msg, sizeof(BytePSCommMsg));
PostNcclCalls(task, this_op);
}
Expand Down Expand Up @@ -337,6 +375,8 @@ bool RunNonRootNcclLoopOnce() {
QueueType this_op = REDUCE;
if (msg.signal == DO_BROADCAST) {
this_op = BROADCAST;
} else if (msg.signal == DO_ALLGATHER) {
this_op = ALLGATHER;
} else {
BPS_CHECK_EQ(msg.signal, DO_REDUCE) << msg.signal << ", " << DO_REDUCE;
}
Expand Down Expand Up @@ -752,6 +792,13 @@ bool RunNonRootCopyHost2DeviceLoopOnce() {
return true;
}

void CoordinateAllgatherLoop() {
while (RunCoordinateLoopOnce(COORDINATE_ALLGATHER) &&
!BytePSGlobal::ShouldShutdown()) {
}
BytePSGlobal::ReportThreadFinish();
}

void CoordinateReduceLoop() {
while (RunCoordinateLoopOnce(COORDINATE_REDUCE) &&
!BytePSGlobal::ShouldShutdown()) {
Expand Down
1 change: 1 addition & 0 deletions byteps/common/core_loops.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace common {
void CoordinateReduceLoop();

void CoordinateBroadcastLoop();
void CoordinateAllgatherLoop();

void CoordinatePushLoop();

Expand Down
10 changes: 9 additions & 1 deletion byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ std::mutex BytePSGlobal::_encode_mutex;
ReadyTable* BytePSGlobal::_reduce_table;
ReadyTable* BytePSGlobal::_pcie_reduce_table;
ReadyTable* BytePSGlobal::_broadcast_table;
ReadyTable* BytePSGlobal::_allgather_table;
ReadyTable* BytePSGlobal::_push_table;
ReadyTable* BytePSGlobal::_copy_table;
bool BytePSGlobal::_is_using_reduce = false;
Expand Down Expand Up @@ -232,6 +233,9 @@ void BytePSGlobal::Init() {
_reduce_table = new ReadyTable(GetPcieSwitchSize() - 1, "NCCL_REDUCE");
_broadcast_table =
new ReadyTable(GetPcieSwitchSize() - 1, "NCCL_BROADCAST");
_allgather_table =
new ReadyTable(GetPcieSwitchSize() - 1, "NCCL_ALLGATHER");
BPS_LOG(DEBUG) << "Created reduce table, broadcast table and alltagher table";
}

// Configure the reduce strategy
Expand Down Expand Up @@ -370,6 +374,10 @@ void BytePSGlobal::Shutdown() {
delete _broadcast_table;
_broadcast_table = NULL;
}
if (_allgather_table) {
delete _allgather_table;
_allgather_table = NULL;
}
if (_push_table) {
delete _push_table;
_push_table = NULL;
Expand Down Expand Up @@ -701,7 +709,7 @@ std::size_t PushPullSpeed::_limit = 1024;
std::chrono::time_point<std::chrono::system_clock> PushPullSpeed::_last_ts;
bool PushPullSpeed::_initialized = false;
bool PushPullSpeed::_should_record =
getenv("BYTEPS_TELEMETRY_ON") ? atoi(getenv("BYTEPS_TELEMETRY_ON")) : true;
getenv("BYTEPS_TELEMETRY_ON") ? atoi(getenv("BYTEPS_TELEMETRY_ON")) : false;

void PushPullSpeed::RecordSpeed(std::shared_ptr<TensorTableEntry> task) {
std::lock_guard<std::mutex> lock(_mtx);
Expand Down
2 changes: 2 additions & 0 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class BytePSGlobal {
static ReadyTable* GetReduceTable() { return _reduce_table; }
static ReadyTable* GetPcieReduceTable() { return _pcie_reduce_table; }
static ReadyTable* GetBroadcastTable() { return _broadcast_table; }
static ReadyTable* GetAllgatherTable() { return _allgather_table; }
static ReadyTable* GetPushTable() { return _push_table; }

// reduce strategies
Expand Down Expand Up @@ -187,6 +188,7 @@ class BytePSGlobal {
static ReadyTable* _reduce_table;
static ReadyTable* _pcie_reduce_table;
static ReadyTable* _broadcast_table;
static ReadyTable* _allgather_table;
static ReadyTable* _push_table;

// (key, ready_signal_count) pair, only valid for non-root device
Expand Down
Loading