From 0f9317b16da926080788a09986ab48b415256d04 Mon Sep 17 00:00:00 2001
From: milakov
Date: Sat, 12 Jul 2014 18:55:09 +0400
Subject: [PATCH] Momentum (in training) implemented

---
 nnforge/cuda/network_updater_cuda.cu    | 111 +++++++++++++++++-----
 nnforge/cuda/network_updater_cuda.h     |   7 +-
 nnforge/network_trainer.cpp             |   2 +
 nnforge/network_trainer.h               |   1 +
 nnforge/network_trainer_sdlm.cpp        |   3 +-
 nnforge/network_trainer_sgd.cpp         |   3 +-
 nnforge/network_updater.cpp             |   5 +-
 nnforge/network_updater.h               |   6 +-
 nnforge/neural_network_toolset.cpp      |   6 +-
 nnforge/neural_network_toolset.h        |   1 +
 nnforge/plain/network_updater_plain.cpp | 119 ++++++++++++++++--------
 nnforge/plain/network_updater_plain.h   |  13 ++-
 12 files changed, 203 insertions(+), 74 deletions(-)

diff --git a/nnforge/cuda/network_updater_cuda.cu b/nnforge/cuda/network_updater_cuda.cu
index 98b903c..8f55d8d 100644
--- a/nnforge/cuda/network_updater_cuda.cu
+++ b/nnforge/cuda/network_updater_cuda.cu
@@ -94,6 +94,31 @@ namespace nnforge
     }
   }

+  __global__ void apply_gradient_with_momentum_kernel(
+    float * __restrict data,
+    float * __restrict gradient,
+    float * __restrict previous_upd,
+    const float * __restrict learning_rate,
+    float normalizer,
+    float weight_decay,
+    float momentum,
+    int elem_count)
+  {
+    int elem_id = blockDim.x * (blockIdx.y * gridDim.x + blockIdx.x) + threadIdx.x;
+    if (elem_id < elem_count)
+    {
+      float current_weight = data[elem_id];
+      float lr = learning_rate[elem_id];
+      float gr = gradient[elem_id];
+      float prev_upd = previous_upd[elem_id];
+      float upd = prev_upd * momentum + lr * (gr * normalizer - current_weight * weight_decay);
+      float new_weight = current_weight + upd;
+      data[elem_id] = new_weight;
+      gradient[elem_id] = 0.0F;
+      previous_upd[elem_id] = upd;
+    }
+  }
+
   unsigned int network_updater_cuda::max_entry_count_in_single_batch = 1024;

   network_updater_cuda::network_updater_cuda(
@@ -155,7 +180,8 @@ namespace nnforge
     network_data_const_smart_ptr learning_rate,
     network_data_smart_ptr data,
     unsigned int batch_size,
-    float weight_decay)
+    float weight_decay,
+    float momentum)
   {
     testing_result_smart_ptr res(new testing_result(ef));

@@ -193,6 +219,9 @@ namespace nnforge
     std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> > net_data = get_data(data);
     std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> > learning_rate_data = get_learning_rate(learning_rate);
     std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> > gradient = get_zero_gradient(net_data);
+    std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> > previous_upd;
+    if (momentum > 0.0F)
+      previous_upd = get_zero_gradient(net_data);

     {
       buffer_cuda_size_configuration buffers_config;
@@ -217,6 +246,9 @@ namespace nnforge
       for(std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::const_iterator it = gradient.begin(); it != gradient.end(); ++it)
         for(std::vector<cuda_linear_buffer_device_smart_ptr>::const_iterator it2 = it->begin(); it2 != it->end(); ++it2)
           buffers_config.add_constant_buffer((*it2)->get_size());
+      for(std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::const_iterator it = previous_upd.begin(); it != previous_upd.end(); ++it)
+        for(std::vector<cuda_linear_buffer_device_smart_ptr>::const_iterator it2 = it->begin(); it2 != it->end(); ++it2)
+          buffers_config.add_constant_buffer((*it2)->get_size());

       unsigned int max_entry_count = std::min(std::min(cuda_config->get_max_entry_count(buffers_config), reader.get_entry_count()), max_entry_count_in_single_batch);

       if (entry_read_count_list.empty() || (max_entry_count >= batch_size))
@@ -557,9 +589,11 @@ namespace nnforge
               *command_stream,
               net_data,
               gradient,
+              previous_upd,
               learning_rate_data,
               gradient_normalizer,
-              weight_decay);
+              weight_decay,
+              momentum);

             entry_gradient_calculated_count = 0;
           }
@@ -600,9 +634,11 @@ namespace nnforge
             *command_stream,
             net_data,
             gradient,
+            previous_upd,
             learning_rate_data,
             gradient_normalizer,
-            weight_decay);
+            weight_decay,
+            momentum);

           entry_gradient_calculated_count = 0;
         }
@@ -757,29 +793,62 @@ namespace nnforge
     cudaStream_t stream_id,
     std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >& data,
     std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >& gradient,
+    std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >& prev_upd,
     std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> >& learning_rate,
     float gradient_normalizer,
-    float weight_decay)
+    float weight_decay,
+    float momentum)
   {
-    std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator gradient_it = gradient.begin();
-    std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> >::iterator learning_rate_it = learning_rate.begin();
-    for(std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator data_it = data.begin(); data_it != data.end(); ++data_it, ++gradient_it, ++learning_rate_it)
+    if (momentum > 0.0F)
     {
-      std::vector<cuda_linear_buffer_device_smart_ptr>::iterator gradient_it2 = gradient_it->begin();
-      std::vector<const_cuda_linear_buffer_device_smart_ptr>::iterator learning_rate_it2 = learning_rate_it->begin();
-      for(std::vector<cuda_linear_buffer_device_smart_ptr>::iterator data_it2 = data_it->begin(); data_it2 != data_it->end(); ++data_it2, ++gradient_it2, ++learning_rate_it2)
+      std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator gradient_it = gradient.begin();
+      std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator prev_upd_it = prev_upd.begin();
+      std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> >::iterator learning_rate_it = learning_rate.begin();
+      for(std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator data_it = data.begin(); data_it != data.end(); ++data_it, ++gradient_it, ++prev_upd_it, ++learning_rate_it)
       {
-        int elem_count = (*data_it2)->get_size() / sizeof(float);
-        std::pair<dim3, dim3> kernel_dims = cuda_util::get_grid_and_threadblock_sizes_sequential_access(
-          *cuda_config,
-          elem_count);
-        apply_gradient_kernel<<<kernel_dims.first, kernel_dims.second, 0, stream_id>>>(
-          **data_it2,
-          **gradient_it2,
-          **learning_rate_it2,
-          gradient_normalizer,
-          weight_decay,
-          elem_count);
+        std::vector<cuda_linear_buffer_device_smart_ptr>::iterator gradient_it2 = gradient_it->begin();
+        std::vector<cuda_linear_buffer_device_smart_ptr>::iterator prev_upd_it2 = prev_upd_it->begin();
+        std::vector<const_cuda_linear_buffer_device_smart_ptr>::iterator learning_rate_it2 = learning_rate_it->begin();
+        for(std::vector<cuda_linear_buffer_device_smart_ptr>::iterator data_it2 = data_it->begin(); data_it2 != data_it->end(); ++data_it2, ++gradient_it2, ++prev_upd_it2, ++learning_rate_it2)
+        {
+          int elem_count = (*data_it2)->get_size() / sizeof(float);
+          std::pair<dim3, dim3> kernel_dims = cuda_util::get_grid_and_threadblock_sizes_sequential_access(
+            *cuda_config,
+            elem_count);
+          apply_gradient_with_momentum_kernel<<<kernel_dims.first, kernel_dims.second, 0, stream_id>>>(
+            **data_it2,
+            **gradient_it2,
+            **prev_upd_it2,
+            **learning_rate_it2,
+            gradient_normalizer,
+            weight_decay,
+            momentum,
+            elem_count);
+        }
+      }
+    }
+    else
+    {
+      std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator gradient_it = gradient.begin();
+      std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> >::iterator learning_rate_it = learning_rate.begin();
+      for(std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::iterator data_it = data.begin(); data_it != data.end(); ++data_it, ++gradient_it, ++learning_rate_it)
+      {
+        std::vector<cuda_linear_buffer_device_smart_ptr>::iterator gradient_it2 = gradient_it->begin();
+        std::vector<const_cuda_linear_buffer_device_smart_ptr>::iterator learning_rate_it2 = learning_rate_it->begin();
+        for(std::vector<cuda_linear_buffer_device_smart_ptr>::iterator data_it2 = data_it->begin(); data_it2 != data_it->end(); ++data_it2, ++gradient_it2, ++learning_rate_it2)
+        {
+          int elem_count = (*data_it2)->get_size() / sizeof(float);
+          std::pair<dim3, dim3> kernel_dims = cuda_util::get_grid_and_threadblock_sizes_sequential_access(
+            *cuda_config,
+            elem_count);
+          apply_gradient_kernel<<<kernel_dims.first, kernel_dims.second, 0, stream_id>>>(
+            **data_it2,
+            **gradient_it2,
+            **learning_rate_it2,
+            gradient_normalizer,
+            weight_decay,
+            elem_count);
+        }
       }
     }
   }
diff --git a/nnforge/cuda/network_updater_cuda.h b/nnforge/cuda/network_updater_cuda.h
index b1baaa9..d834bf9 100644
--- a/nnforge/cuda/network_updater_cuda.h
+++ b/nnforge/cuda/network_updater_cuda.h
@@ -46,7 +46,8 @@ namespace nnforge
       network_data_const_smart_ptr learning_rate,
       network_data_smart_ptr data,
       unsigned int batch_size,
-      float weight_decay);
+      float weight_decay,
+      float momentum);

     // The method is called when client calls set_input_configuration_specific and the convolution specific configuration is modified.
     // The layer_config_list is guaranteed to be compatible with schema
@@ -88,9 +89,11 @@ namespace nnforge
       cudaStream_t stream_id,
       std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >& data,
       std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >& gradient,
+      std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >& prev_upd,
       std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> >& learning_rate,
       float gradient_normalizer,
-      float weight_decay);
+      float weight_decay,
+      float momentum);

     cuda_running_configuration_const_smart_ptr cuda_config;

diff --git a/nnforge/network_trainer.cpp b/nnforge/network_trainer.cpp
index 738c7e3..ae84396 100644
--- a/nnforge/network_trainer.cpp
+++ b/nnforge/network_trainer.cpp
@@ -28,6 +28,8 @@ namespace nnforge
     , learning_rate_decay_tail_epoch_count(0)
     , learning_rate_decay_rate(0.5F)
     , learning_rate(0.02F)
+    , batch_size(1)
+    , momentum(0.0F)
   {
   }

diff --git a/nnforge/network_trainer.h b/nnforge/network_trainer.h
index aac23f6..3648969 100644
--- a/nnforge/network_trainer.h
+++ b/nnforge/network_trainer.h
@@ -47,6 +47,7 @@ namespace nnforge
     unsigned int learning_rate_rise_head_epoch_count;
     float learning_rate_rise_rate;
     float weight_decay;
+    float momentum;

   protected:
     network_trainer(network_schema_smart_ptr schema);
diff --git a/nnforge/network_trainer_sdlm.cpp b/nnforge/network_trainer_sdlm.cpp
index 6848455..ce879ec 100644
--- a/nnforge/network_trainer_sdlm.cpp
+++ b/nnforge/network_trainer_sdlm.cpp
@@ -72,7 +72,8 @@ namespace nnforge
       learning_rate,
       task.data,
       batch_size,
-      weight_decay);
+      weight_decay,
+      momentum);

     boost::chrono::duration<float> sec = (boost::chrono::high_resolution_clock::now() - start);

diff --git a/nnforge/network_trainer_sgd.cpp b/nnforge/network_trainer_sgd.cpp
index b858538..4671d20 100644
--- a/nnforge/network_trainer_sgd.cpp
+++ b/nnforge/network_trainer_sgd.cpp
@@ -50,7 +50,8 @@ namespace nnforge
       lr_and_comment.first,
       task.data,
       batch_size,
-      weight_decay);
+      weight_decay,
+      momentum);

     boost::chrono::duration<float> sec = (boost::chrono::high_resolution_clock::now() - start);

diff --git a/nnforge/network_updater.cpp b/nnforge/network_updater.cpp
index 1a1592f..edc5ef2 100644
--- a/nnforge/network_updater.cpp
+++ b/nnforge/network_updater.cpp
@@ -79,7 +79,8 @@ namespace nnforge
     network_data_const_smart_ptr learning_rate,
     network_data_smart_ptr data,
     unsigned int batch_size,
-    float weight_decay)
+    float weight_decay,
+    float momentum)
   {
     // Check data-schema consistency
     data->check_network_data_consistency(*schema);
@@ -96,7 +97,7 @@ namespace nnforge

     data->apply_dropout_layer_config(layer_id_to_dropout_config_map, false);

-    testing_result_smart_ptr res = actual_update(reader, learning_rate, data, batch_size, weight_decay);
+    testing_result_smart_ptr res = actual_update(reader, learning_rate, data, batch_size, weight_decay, momentum);

     data->apply_dropout_layer_config(layer_id_to_dropout_config_map, true);

diff --git a/nnforge/network_updater.h b/nnforge/network_updater.h
index 9a802ab..31122fe 100644
--- a/nnforge/network_updater.h
+++ b/nnforge/network_updater.h
@@ -43,7 +43,8 @@ namespace nnforge
       network_data_const_smart_ptr learning_rate,
       network_data_smart_ptr data,
       unsigned int batch_size,
-      float weight_decay);
+      float weight_decay,
+      float momentum);

     // set_input_configuration_specific should be called prior to this method call for this method to succeed
     float get_flops_for_single_entry() const;
@@ -60,7 +61,8 @@ namespace nnforge
       network_data_const_smart_ptr learning_rate,
       network_data_smart_ptr data,
       unsigned int batch_size,
-      float weight_decay) = 0;
+      float weight_decay,
+      float momentum) = 0;

     // The method is called when client calls set_input_configuration_specific and the convolution specific configuration is modified.
     // The layer_config_list is guaranteed to be compatible with schema
diff --git a/nnforge/neural_network_toolset.cpp b/nnforge/neural_network_toolset.cpp
index 848e833..1a15f11 100644
--- a/nnforge/neural_network_toolset.cpp
+++ b/nnforge/neural_network_toolset.cpp
@@ -209,6 +209,7 @@ namespace nnforge
       ("epoch_count_in_training_set", boost::program_options::value<unsigned int>(&epoch_count_in_training_set)->default_value(1), "The whole should be split in this amount of epochs.")
       ("weight_decay", boost::program_options::value<float>(&weight_decay)->default_value(0.0F), "Weight decay.")
       ("batch_size,B", boost::program_options::value<unsigned int>(&batch_size)->default_value(1), "Training mini-batch size.")
+      ("momentum,M", boost::program_options::value<float>(&momentum)->default_value(0.0F), "Momentum in training.")
       ;

     {
@@ -369,6 +370,7 @@ namespace nnforge
       std::cout << "epoch_count_in_training_set" << "=" << epoch_count_in_training_set << std::endl;
       std::cout << "weight_decay" << "=" << weight_decay << std::endl;
       std::cout << "batch_size" << "=" << batch_size << std::endl;
+      std::cout << "momentum" << "=" << momentum << std::endl;
     }
     {
       std::vector<string_option> additional_string_options = get_string_options();
@@ -490,6 +492,7 @@ namespace nnforge
     res->learning_rate_rise_rate = learning_rate_rise_rate;
     res->weight_decay = weight_decay;
     res->batch_size = batch_size;
+    res->momentum = momentum;

     return res;
   }
@@ -1456,7 +1459,8 @@ namespace nnforge
       learning_rates,
       data,
       batch_size,
-      weight_decay);
+      weight_decay,
+      momentum);
     boost::chrono::duration<float> sec = boost::chrono::high_resolution_clock::now() - start;
     /*
     {
diff --git a/nnforge/neural_network_toolset.h b/nnforge/neural_network_toolset.h
index 8440d50..73a22f7 100644
--- a/nnforge/neural_network_toolset.h
+++ b/nnforge/neural_network_toolset.h
@@ -187,6 +187,7 @@ namespace nnforge
     float weight_decay;
     unsigned int snapshot_scale;
     unsigned int batch_size;
+    float momentum;

   protected:
     std::vector<output_neuron_value_set_smart_ptr> run_batch(
diff --git a/nnforge/plain/network_updater_plain.cpp b/nnforge/plain/network_updater_plain.cpp
index 21b0a56..a1e6bee 100644
--- a/nnforge/plain/network_updater_plain.cpp
+++ b/nnforge/plain/network_updater_plain.cpp
@@ -83,7 +83,8 @@ namespace nnforge
     network_data_const_smart_ptr learning_rate,
     network_data_smart_ptr data,
     unsigned int batch_size,
-    float weight_decay)
+    float weight_decay,
+    float momentum)
   {
     testing_result_smart_ptr res(new testing_result(ef));

@@ -118,6 +119,12 @@ namespace nnforge
     network_data_smart_ptr gradient(new network_data(*schema));
     gradient->fill(0.0F);
+    network_data_smart_ptr previous_upd;
+    if (momentum > 0.0F)
+    {
+      previous_upd = network_data_smart_ptr(new network_data(*schema));
+      previous_upd->fill(0.0F);
+    }

     {
       buffer_plain_size_configuration buffers_config;
@@ -448,17 +455,14 @@ namespace nnforge
         if (entry_gradient_calculated_count >= batch_size)
         {
           float gradient_normalizer = 1.0F / static_cast<float>(std::max(batch_size, entry_gradient_calculated_count));
-          layer_data_list::iterator gradient_it = gradient->begin() + testing_layer_count;
-          layer_data_list::const_iterator learning_rate_it = learning_rate->begin() + testing_layer_count;
-          for(layer_data_list::iterator data_it = data->begin() + testing_layer_count; data_it != data->end(); ++data_it, ++gradient_it, ++learning_rate_it)
-          {
-            apply_gradient(
-              *data_it,
-              *gradient_it,
-              *learning_rate_it,
-              gradient_normalizer,
-              weight_decay);
-          }
+          apply_gradient(
+            *data,
+            *gradient,
+            *previous_upd,
+            *learning_rate,
+            gradient_normalizer,
+            weight_decay,
+            momentum);
           entry_gradient_calculated_count = 0;
         }
       }
@@ -467,17 +471,14 @@ namespace nnforge
       if (entry_gradient_calculated_count > 0)
       {
         float gradient_normalizer = 1.0F / static_cast<float>(std::max(batch_size, entry_gradient_calculated_count));
-        layer_data_list::iterator gradient_it = gradient->begin() + testing_layer_count;
-        layer_data_list::const_iterator learning_rate_it = learning_rate->begin() + testing_layer_count;
-        for(layer_data_list::iterator data_it = data->begin() + testing_layer_count; data_it != data->end(); ++data_it, ++gradient_it, ++learning_rate_it)
-        {
-          apply_gradient(
-            *data_it,
-            *gradient_it,
-            *learning_rate_it,
-            gradient_normalizer,
-            weight_decay);
-        }
+        apply_gradient(
+          *data,
+          *gradient,
+          *previous_upd,
+          *learning_rate,
+          gradient_normalizer,
+          weight_decay,
+          momentum);
         entry_gradient_calculated_count = 0;
       }

@@ -489,26 +490,66 @@ namespace nnforge
   }

   void network_updater_plain::apply_gradient(
-    layer_data_smart_ptr data,
-    layer_data_smart_ptr gradient,
-    const_layer_data_smart_ptr learning_rate,
+    std::vector<layer_data_smart_ptr>& data,
+    std::vector<layer_data_smart_ptr>& gradient,
+    std::vector<layer_data_smart_ptr>& previous_upd,
+    const std::vector<layer_data_smart_ptr>& learning_rate,
     float normalizer,
-    float weight_decay) const
+    float weight_decay,
+    float momentum) const
   {
-    layer_data::iterator gradient_it = gradient->begin();
-    layer_data::const_iterator learning_rate_it = learning_rate->begin();
-    for(layer_data::iterator data_it = data->begin(); data_it != data->end(); ++data_it, ++gradient_it, ++learning_rate_it)
+    if (momentum > 0.0F)
     {
-      std::vector<float>::iterator gradient_it2 = gradient_it->begin();
-      std::vector<float>::const_iterator learning_rate_it2 = learning_rate_it->begin();
-      for(std::vector<float>::iterator data_it2 = data_it->begin(); data_it2 != data_it->end(); ++data_it2, ++gradient_it2, ++learning_rate_it2)
+      layer_data_list::iterator gradient_it0 = gradient.begin() + testing_layer_count;
+      layer_data_list::iterator previous_upd_it0 = previous_upd.begin() + testing_layer_count;
+      layer_data_list::const_iterator learning_rate_it0 = learning_rate.begin() + testing_layer_count;
+      for(layer_data_list::iterator data_it0 = data.begin() + testing_layer_count; data_it0 != data.end(); ++data_it0, ++gradient_it0, ++previous_upd_it0, ++learning_rate_it0)
       {
-        float current_weight = *data_it2;
-        float lr = *learning_rate_it2;
-        float gr = *gradient_it2;
-        float new_weight = current_weight + lr * (gr * normalizer - current_weight * weight_decay);
-        *data_it2 = new_weight;
-        *gradient_it2 = 0.0F;
+        layer_data::iterator gradient_it = (*gradient_it0)->begin();
+        layer_data::iterator previous_upd_it = (*previous_upd_it0)->begin();
+        layer_data::const_iterator learning_rate_it = (*learning_rate_it0)->begin();
+        for(layer_data::iterator data_it = (*data_it0)->begin(); data_it != (*data_it0)->end(); ++data_it, ++gradient_it, ++previous_upd_it, ++learning_rate_it)
+        {
+          std::vector<float>::iterator gradient_it2 = gradient_it->begin();
+          std::vector<float>::iterator previous_upd_it2 = previous_upd_it->begin();
+          std::vector<float>::const_iterator learning_rate_it2 = learning_rate_it->begin();
+          for(std::vector<float>::iterator data_it2 = data_it->begin(); data_it2 != data_it->end(); ++data_it2, ++gradient_it2, ++previous_upd_it2, ++learning_rate_it2)
+          {
+            float current_weight = *data_it2;
+            float lr = *learning_rate_it2;
+            float gr = *gradient_it2;
+            float prev_upd = *previous_upd_it2;
+            float upd = prev_upd * momentum + lr * (gr * normalizer - current_weight * weight_decay);
+            float new_weight = current_weight + upd;
+            *data_it2 = new_weight;
+            *gradient_it2 = 0.0F;
+            *previous_upd_it2 = upd;
+          }
+        }
+      }
+    }
+    else
+    {
+      layer_data_list::iterator gradient_it0 = gradient.begin() + testing_layer_count;
+      layer_data_list::const_iterator learning_rate_it0 = learning_rate.begin() + testing_layer_count;
+      for(layer_data_list::iterator data_it0 = data.begin() + testing_layer_count; data_it0 != data.end(); ++data_it0, ++gradient_it0, ++learning_rate_it0)
+      {
+        layer_data::iterator gradient_it = (*gradient_it0)->begin();
+        layer_data::const_iterator learning_rate_it = (*learning_rate_it0)->begin();
+        for(layer_data::iterator data_it = (*data_it0)->begin(); data_it != (*data_it0)->end(); ++data_it, ++gradient_it, ++learning_rate_it)
+        {
+          std::vector<float>::iterator gradient_it2 = gradient_it->begin();
+          std::vector<float>::const_iterator learning_rate_it2 = learning_rate_it->begin();
+          for(std::vector<float>::iterator data_it2 = data_it->begin(); data_it2 != data_it->end(); ++data_it2, ++gradient_it2, ++learning_rate_it2)
+          {
+            float current_weight = *data_it2;
+            float lr = *learning_rate_it2;
+            float gr = *gradient_it2;
+            float new_weight = current_weight + lr * (gr * normalizer - current_weight * weight_decay);
+            *data_it2 = new_weight;
+            *gradient_it2 = 0.0F;
+          }
+        }
       }
     }
   }
diff --git a/nnforge/plain/network_updater_plain.h b/nnforge/plain/network_updater_plain.h
index 2de06fd..97d7dbe 100644
--- a/nnforge/plain/network_updater_plain.h
+++ b/nnforge/plain/network_updater_plain.h
@@ -45,7 +45,8 @@ namespace nnforge
       network_data_const_smart_ptr learning_rate,
       network_data_smart_ptr data,
       unsigned int batch_size,
-      float weight_decay);
+      float weight_decay,
+      float momentum);

     // The method is called when client calls set_input_configuration_specific and the convolution specific configuration is modified.
     // The layer_config_list is guaranteed to be compatible with schema
@@ -69,11 +70,13 @@ namespace nnforge
       const unsigned int offset_in_random_list) const;

     void apply_gradient(
-      layer_data_smart_ptr data,
-      layer_data_smart_ptr gradient,
-      const_layer_data_smart_ptr learning_rate,
+      std::vector<layer_data_smart_ptr>& data,
+      std::vector<layer_data_smart_ptr>& gradient,
+      std::vector<layer_data_smart_ptr>& previous_upd,
+      const std::vector<layer_data_smart_ptr>& learning_rate,
       float normalizer,
-      float weight_decay) const;
+      float weight_decay,
+      float momentum) const;

     plain_running_configuration_const_smart_ptr plain_config;
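
Note (not part of the patch): both new code paths, apply_gradient_with_momentum_kernel on CUDA and the momentum branch of network_updater_plain::apply_gradient, apply the same classical momentum rule per weight. A minimal standalone sketch of that rule, with illustrative names only:

    #include <cstddef>
    #include <vector>

    // Per-weight update, as performed by both back-ends above:
    //   upd     = momentum * prev_upd + lr * (grad * normalizer - weight * weight_decay)
    //   weight += upd; the gradient accumulator is then cleared for the next mini-batch.
    void apply_gradient_with_momentum_sketch(
        std::vector<float>& weights,
        std::vector<float>& gradients,            // gradients summed over the mini-batch
        std::vector<float>& previous_upd,         // previous update per weight, initially zero
        const std::vector<float>& learning_rates, // per-weight learning rates
        float normalizer,                         // 1 / number of entries accumulated
        float weight_decay,
        float momentum)
    {
        for(std::size_t i = 0; i < weights.size(); ++i)
        {
            float upd = previous_upd[i] * momentum
                + learning_rates[i] * (gradients[i] * normalizer - weights[i] * weight_decay);
            weights[i] += upd;
            previous_upd[i] = upd;
            gradients[i] = 0.0F;
        }
    }

With momentum left at 0 (the default of the new --momentum / -M option), the update reduces to the pre-existing SGD rule, so existing training configurations behave as before.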