diff --git a/.clang-format b/.clang-format index 4787be27..4f8ee013 100644 --- a/.clang-format +++ b/.clang-format @@ -35,6 +35,7 @@ PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 60 PointerAlignment: Right +SortIncludes: false SpaceBeforeParens: ControlStatements SpaceAfterCStyleCast: false SpaceAfterTemplateKeyword: false diff --git a/example/demo.cc b/example/demo.cc index 981a6027..826b07e5 100644 --- a/example/demo.cc +++ b/example/demo.cc @@ -68,7 +68,7 @@ int main(int argc, char *argv[]) { // IMX219 Node imx = b.add("image_io_imx219") .set_params(Param("index", i), - Param("url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/pedestrian.png")); + Param("url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/pedestrian.png")); // ISP Node downscale = b.add("image_processing_bayer_downscale_uint16") @@ -205,7 +205,6 @@ int main(int argc, char *argv[]) { d435["output_d"].bind(depth_buf); sgm["output"].bind(sgm_buf); - { // Execution execution b.run(); @@ -220,7 +219,7 @@ int main(int argc, char *argv[]) { // cv::imwrite("demo-depth.png", depth_img); cv::imwrite("demo-sgm.png", sgm_img); cv::imwrite("demo-yolo.png", yolo_img); - std::cout<<"Passed"< HWC n = b.add("dnn_object_detection")(n["output"]); n = b.add("base_denormalize_3d_uint8")(n["output"]); diff --git a/example/gender_count.cc b/example/gender_count.cc index 3fcdebb7..33316712 100644 --- a/example/gender_count.cc +++ b/example/gender_count.cc @@ -38,7 +38,7 @@ int main(int argc, char *argv[]) { out_p1.bind(out1); Halide::Buffer out2 = Halide::Buffer::make_scalar(); out_p2.bind(out2); - for (int i=0; i<100; ++i) { + for (int i = 0; i < 100; ++i) { b.run(); } } catch (const std::exception &e) { diff --git a/example/imx219_isp_display.cc b/example/imx219_isp_display.cc index 4bd7c877..b71b6ece 100644 --- a/example/imx219_isp_display.cc +++ b/example/imx219_isp_display.cc @@ -28,7 +28,7 @@ int main(int argc, char *argv[]) { constexpr int32_t num = 2; Node n; Node ns[num]; - for (int i=0; i in_buf(3, width, height); - for (int y=0; y r = Halide::Buffer::make_scalar(); - in.bind(in_buf); n["output"].bind(r); - for (int i=0; i<1000; ++i) { + for (int i = 0; i < 1000; ++i) { b.run(); } diff --git a/example/u3v.cc b/example/u3v.cc index 98d700c4..a77b3808 100644 --- a/example/u3v.cc +++ b/example/u3v.cc @@ -48,20 +48,20 @@ struct rawHeader { int pfnc_pixelformat; }; -void display_and_save(int32_t width, int32_t height, std::string directory_path, rawHeader header_info, bool last_run){ +void display_and_save(int32_t width, int32_t height, std::string directory_path, rawHeader header_info, bool last_run) { Builder b; b.set_target(Halide::get_host_target()); - Port dispose_camera{ "dispose_camera", Halide::type_of() }; - Port dispose_writer{ "dispose_writer", Halide::type_of() }; + Port dispose_camera{"dispose_camera", Halide::type_of()}; + Port dispose_writer{"dispose_writer", Halide::type_of()}; - Port gain0_p{ "gain0", Halide::type_of() }; - Port gain1_p{ "gain1", Halide::type_of() }; - Port exposure0_p{ "exposure0", Halide::type_of() }; - Port exposure1_p{ "exposure1", Halide::type_of() }; - Port wp{ "width", Halide::type_of() }; - Port hp{ "height", Halide::type_of() }; + Port gain0_p{"gain0", Halide::type_of()}; + Port gain1_p{"gain1", Halide::type_of()}; + Port exposure0_p{"exposure0", Halide::type_of()}; + Port exposure1_p{"exposure1", Halide::type_of()}; + Port wp{"width", Halide::type_of()}; + Port hp{"height", Halide::type_of()}; Port r_gain0_p{"r_gain0", Halide::type_of()}; Port g_gain0_p{"g_gain0", Halide::type_of()}; @@ -71,82 +71,50 @@ void display_and_save(int32_t width, int32_t height, std::string directory_path, Port g_gain1_p{"g_gain1", Halide::type_of()}; Port b_gain1_p{"b_gain1", Halide::type_of()}; - // obtain sensor images auto n = b.add("image_io_u3v_camera2_u16x2")(dispose_camera, gain0_p, gain1_p, exposure0_p, exposure1_p) - .set_params( - Param{"pixel_format_ptr", PIXEL_FORMAT}, - Param{"frame_sync", "true"}, - Param{"gain_key", FEATURE_GAIN_KEY}, - Param{"exposure_key", FEATURE_EXPOSURE_KEY} - ); + .set_params( + Param{"pixel_format_ptr", PIXEL_FORMAT}, + Param{"frame_sync", "true"}, + Param{"gain_key", FEATURE_GAIN_KEY}, + Param{"exposure_key", FEATURE_EXPOSURE_KEY}); Port lp = n["output0"]; Port rp = n["output1"]; Port fcp = n["frame_count"]; - n = b.add("image_io_binarysaver")(rp, lp, fcp, dispose_writer, wp, hp).set_params( - Param{"output_directory", directory_path}, - Param{"fps", "60.0"}); + n = b.add("image_io_binarysaver")(rp, lp, fcp, dispose_writer, wp, hp).set_params(Param{"output_directory", directory_path}, Param{"fps", "60.0"}); Port terminator = n["output"]; /* image processing on the iamge obtained from the left sensor */ n = b.add("image_processing_normalize_raw_image")(lp).set_params(Param{"bit_width", "12"}, Param{"bit_shift", "0"}); n = b.add("image_processing_bayer_white_balance")(r_gain0_p, g_gain0_p, b_gain0_p, n["output"]).set_params(Param{"bayer_pattern", "GBRG"}); - n = b.add("image_processing_bayer_demosaic_simple")(n["output"]).set_params( - Param{"bayer_pattern", "GBRG"}, - Param{"width", std::to_string(width)}, - Param{"height", std::to_string(height)} - ); - n = b.add("image_processing_resize_bilinear_3d")(n["output"]).set_params( - Param{"width", std::to_string(width)}, - Param{"height", std::to_string(height)}, - Param{"scale", std::to_string(2.0f)} - ); + n = b.add("image_processing_bayer_demosaic_simple")(n["output"]).set_params(Param{"bayer_pattern", "GBRG"}, Param{"width", std::to_string(width)}, Param{"height", std::to_string(height)}); + n = b.add("image_processing_resize_bilinear_3d")(n["output"]).set_params(Param{"width", std::to_string(width)}, Param{"height", std::to_string(height)}, Param{"scale", std::to_string(2.0f)}); n = b.add("base_denormalize_3d_uint8")(n["output"]); - n = b.add("image_processing_crop_image_3d_uint8")(n["output"]).set_params( - Param{"input_width", std::to_string(width)}, - Param{"input_height", std::to_string(height)}, - Param{"output_width", std::to_string(width)}, - Param{"output_height", std::to_string(height)} - ); /*optional*/ + n = b.add("image_processing_crop_image_3d_uint8")(n["output"]).set_params(Param{"input_width", std::to_string(width)}, Param{"input_height", std::to_string(height)}, Param{"output_width", std::to_string(width)}, Param{"output_height", std::to_string(height)}); /*optional*/ lp = n["output"]; /* image processing on the iamge obtained from the right sensor */ n = b.add("image_processing_normalize_raw_image")(rp).set_params(Param{"bit_width", "12"}, Param{"bit_shift", "0"}); n = b.add("image_processing_bayer_white_balance")(r_gain1_p, g_gain1_p, b_gain1_p, n["output"]).set_params(Param{"bayer_pattern", "GBRG"}); - n = b.add("image_processing_bayer_demosaic_simple")(n["output"]).set_params( - Param{"bayer_pattern", "GBRG"}, - Param{"width", std::to_string(width)}, - Param{"height", std::to_string(height)} - ); - n = b.add("image_processing_resize_bilinear_3d")(n["output"]).set_params( - Param{"width", std::to_string(width)}, - Param{"height", std::to_string(height)}, - Param{"scale", std::to_string(2.0f)} - ); + n = b.add("image_processing_bayer_demosaic_simple")(n["output"]).set_params(Param{"bayer_pattern", "GBRG"}, Param{"width", std::to_string(width)}, Param{"height", std::to_string(height)}); + n = b.add("image_processing_resize_bilinear_3d")(n["output"]).set_params(Param{"width", std::to_string(width)}, Param{"height", std::to_string(height)}, Param{"scale", std::to_string(2.0f)}); n = b.add("base_denormalize_3d_uint8")(n["output"]); - n = b.add("image_processing_crop_image_3d_uint8")(n["output"]).set_params( - Param{"input_width", std::to_string(width)}, - Param{"input_height", std::to_string(height)}, - Param{"output_width", std::to_string(width)}, - Param{"output_height", std::to_string(height)} - ); /*optional*/ + n = b.add("image_processing_crop_image_3d_uint8")(n["output"]).set_params(Param{"input_width", std::to_string(width)}, Param{"input_height", std::to_string(height)}, Param{"output_width", std::to_string(width)}, Param{"output_height", std::to_string(height)}); /*optional*/ rp = n["output"]; // display images n = b.add("image_io_gui_display")(lp).set_params( - Param{"idx", "0"}, - Param{"width", std::to_string(width)}, - Param{"height", std::to_string(height)} - ); + Param{"idx", "0"}, + Param{"width", std::to_string(width)}, + Param{"height", std::to_string(height)}); Port display_output0_p = n["output"]; n = b.add("image_io_gui_display")(rp).set_params( - Param{"idx", "1"}, - Param{"width", std::to_string(width)}, - Param{"height", std::to_string(height)}); + Param{"idx", "1"}, + Param{"width", std::to_string(width)}, + Param{"height", std::to_string(height)}); Port display_output1_p = n["output"]; - /* input */ wp.bind(&width); hp.bind(&height); @@ -175,14 +143,14 @@ void display_and_save(int32_t width, int32_t height, std::string directory_path, int loop_num = 400; - for (int i=0; i< loop_num; ++i) { + for (int i = 0; i < loop_num; ++i) { b.run(); } cv::destroyAllWindows(); } -void open_and_check(int32_t& width, int32_t& height, const std::filesystem::path output_directory, uint32_t& file_idx, std::ifstream& ifs, bool *finished) { +void open_and_check(int32_t &width, int32_t &height, const std::filesystem::path output_directory, uint32_t &file_idx, std::ifstream &ifs, bool *finished) { auto file_path = output_directory / ("raw-" + ::std::to_string(file_idx++) + ".bin"); ifs = ::std::ifstream(file_path, ::std::ios::binary); @@ -192,9 +160,9 @@ void open_and_check(int32_t& width, int32_t& height, const std::filesystem::path } int32_t version = 0; - ifs.read(reinterpret_cast(&version), sizeof(int32_t)); - ifs.read(reinterpret_cast(&width), sizeof(int32_t)); - ifs.read(reinterpret_cast(&height), sizeof(int32_t)); + ifs.read(reinterpret_cast(&version), sizeof(int32_t)); + ifs.read(reinterpret_cast(&width), sizeof(int32_t)); + ifs.read(reinterpret_cast(&height), sizeof(int32_t)); ifs = ::std::ifstream(file_path, ::std::ios::binary); @@ -202,16 +170,17 @@ void open_and_check(int32_t& width, int32_t& height, const std::filesystem::path ifs.seekg(512, ::std::ios_base::beg); } -bool load_header_file(std::filesystem::path output_directory, rawHeader header_info) -{ +bool load_header_file(std::filesystem::path output_directory, rawHeader header_info) { std::ifstream ifs; - int width_, height_; + int width_, height_; bool finished_ = false; uint32_t file_idx_ = 0; // first look open_and_check(width_, height_, output_directory, file_idx_, ifs, &finished_); - if (finished_) { return false; } + if (finished_) { + return false; + } bool ret = true; ret = ret && width_ == header_info.width_; @@ -224,11 +193,11 @@ bool load_header_file(std::filesystem::path output_directory, rawHeader header_i /* there's no audio recording feature yet so offset != 0 */ uint32_t offset_frame_count; - const size_t size = static_cast(width_ *height_*sizeof(uint16_t)); + const size_t size = static_cast(width_ * height_ * sizeof(uint16_t)); // first frame count ifs.read(reinterpret_cast(&offset_frame_count), sizeof(offset_frame_count)); - ifs.seekg(2*size, std::ios::cur); + ifs.seekg(2 * size, std::ios::cur); uint32_t frame_index = offset_frame_count; ofs << "offset_frame_count: " << offset_frame_count << "\n"; @@ -236,12 +205,13 @@ bool load_header_file(std::filesystem::path output_directory, rawHeader header_i uint32_t frame_count = frame_index; ofs << frame_index++ << " : " << frame_count << "\n"; - uint skip_count = 0; + uint skip_count = 0; while (!finished_) { ifs.read(reinterpret_cast(&frame_count), sizeof(frame_count)); - while( frame_index < frame_count ){ - ofs << frame_index++ << " : x" << "\n"; + while (frame_index < frame_count) { + ofs << frame_index++ << " : x" + << "\n"; ++skip_count; } ofs << frame_index++ << " : " << frame_count << "\n"; @@ -257,8 +227,8 @@ bool load_header_file(std::filesystem::path output_directory, rawHeader header_i } uint total_frame = frame_count - offset_frame_count; - std::cout << (total_frame-skip_count)*1.0 / total_frame << std::endl; - ofs << (total_frame-skip_count)*1.0 / total_frame << "\n"; + std::cout << (total_frame - skip_count) * 1.0 / total_frame << std::endl; + ofs << (total_frame - skip_count) * 1.0 / total_frame << "\n"; ofs.close(); ifs.close(); @@ -273,7 +243,7 @@ int main() { std::filesystem::path test_directory = "u3v_framerate_test"; std::string output_directory_prefix = "u3v_framerate_test"; - if(! std::filesystem::is_directory(test_directory)){ + if (!std::filesystem::is_directory(test_directory)) { bool ret = std::filesystem::create_directory(test_directory); } @@ -284,16 +254,16 @@ int main() { int num_run = 50; - for (int i = 0; i < num_run; ++i){ + for (int i = 0; i < num_run; ++i) { std::filesystem::path output_directory = test_directory / (output_directory_prefix + std::to_string(i)); - if(! std::filesystem::is_directory(output_directory)){ + if (!std::filesystem::is_directory(output_directory)) { bool ret = std::filesystem::create_directory(output_directory); } display_and_save(width, height, output_directory.string(), header_info, i == num_run - 1); bool ret = load_header_file(output_directory, header_info); - if (!ret){ - std::runtime_error("header info is incorrect at test " + std::to_string(i) ); + if (!ret) { + std::runtime_error("header info is incorrect at test " + std::to_string(i)); } } return 0; diff --git a/example/u3v_camera1_opencv/u3v_camera1_opencv.cc b/example/u3v_camera1_opencv/u3v_camera1_opencv.cc index 7e5e0ad4..b690caa0 100644 --- a/example/u3v_camera1_opencv/u3v_camera1_opencv.cc +++ b/example/u3v_camera1_opencv/u3v_camera1_opencv.cc @@ -29,12 +29,11 @@ int positive_pow(int base, int expo) { if (expo == 1) { return base; } else { - return base * positive_pow(base, expo-1); + return base * positive_pow(base, expo - 1); } } -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { try { // Define builders to build, compile, and execute pipelines. // Build the pipeline by adding nodes to the Builder. @@ -48,19 +47,18 @@ int main(int argc, char *argv[]) // Connect the input port to the Node instance created by b.add(). Node n = b.add("image_io_u3v_cameraN_u16x2")(&gain, &exposure) - .set_params( - Param("num_devices", 1), - Param("frame_sync", false), - Param("gain_key", FEATURE_GAIN_KEY), - Param("exposure_key", FEATURE_EXPOSURE_KEY), - Param("realtime_display_mode", true), - Param("enable_control", true) - ); + .set_params( + Param("num_devices", 1), + Param("frame_sync", false), + Param("gain_key", FEATURE_GAIN_KEY), + Param("exposure_key", FEATURE_EXPOSURE_KEY), + Param("realtime_display_mode", true), + Param("enable_control", true)); // Map output buffer and ports by using Port::bind. // - output: output of the obtained video data // - frame_count: output of the frame number of the obtained video - std::vector< int > buf_size = std::vector < int >{ width, height }; + std::vector buf_size = std::vector{width, height}; Buffer output(buf_size); Buffer frame_count(1); @@ -69,9 +67,8 @@ int main(int argc, char *argv[]) // Obtain image data continuously for 100 frames to facilitate operation check. int loop_num = 100; - int coef = positive_pow(2, NUM_BIT_SHIFT); - for (int i = 0; i < loop_num; ++i) - { + int coef = positive_pow(2, NUM_BIT_SHIFT); + for (int i = 0; i < loop_num; ++i) { // JIT compilation and execution of pipelines with Builder. b.run(); @@ -89,10 +86,10 @@ int main(int argc, char *argv[]) cv::waitKey(1); } - } catch (const ion::Error& e) { + } catch (const ion::Error &e) { std::cerr << e.what() << std::endl; return 1; } - return 0; + return 0; } diff --git a/example/u3v_camera2_opencv/u3v_camera2_opencv.cc b/example/u3v_camera2_opencv/u3v_camera2_opencv.cc index 2fa36426..7cc5373c 100644 --- a/example/u3v_camera2_opencv/u3v_camera2_opencv.cc +++ b/example/u3v_camera2_opencv/u3v_camera2_opencv.cc @@ -29,13 +29,11 @@ int positive_pow(int base, int expo) { if (expo == 1) { return base; } else { - return base * positive_pow(base, expo-1); + return base * positive_pow(base, expo - 1); } } - -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { try { // Define builders to build, compile, and execute pipelines. // Build the pipeline by adding nodes to the Builder. @@ -49,19 +47,18 @@ int main(int argc, char *argv[]) // Connect the input port to the Node instance created by b.add(). Node n = b.add("image_io_u3v_cameraN_u16x2")(&gain, &exposure, &gain, &exposure) - .set_params( - Param("frame_sync", false), - Param("gain_key", FEATURE_GAIN_KEY), - Param("exposure_key", FEATURE_EXPOSURE_KEY), - Param("realtime_display_mode", true), - Param("enable_control", true) - ); + .set_params( + Param("frame_sync", false), + Param("gain_key", FEATURE_GAIN_KEY), + Param("exposure_key", FEATURE_EXPOSURE_KEY), + Param("realtime_display_mode", true), + Param("enable_control", true)); // Map output buffer and ports by using Port::bind. // - output0: output 0 of the obtained video data // - output1: output 1 of the obtained video data // - frame_count: output of the frame number of the obtained video - std::vector< int > buf_size = std::vector < int >{ width, height }; + std::vector buf_size = std::vector{width, height}; Buffer output0(buf_size); Buffer output1(buf_size); Buffer frame_count(1); @@ -71,9 +68,8 @@ int main(int argc, char *argv[]) // Obtain image data continuously for 100 frames to facilitate operation check. int loop_num = 100; - int coef = positive_pow(2, NUM_BIT_SHIFT); - for (int i = 0; i < loop_num; ++i) - { + int coef = positive_pow(2, NUM_BIT_SHIFT); + for (int i = 0; i < loop_num; ++i) { // JIT compilation and execution of pipelines with Builder. b.run(); @@ -103,10 +99,10 @@ int main(int argc, char *argv[]) cv::waitKey(1); } - } catch (const ion::Error& e) { + } catch (const ion::Error &e) { std::cerr << e.what() << std::endl; return 1; } - return 0; + return 0; } diff --git a/example/u3v_fake.cc b/example/u3v_fake.cc index 490593b7..60974c65 100644 --- a/example/u3v_fake.cc +++ b/example/u3v_fake.cc @@ -11,8 +11,7 @@ using namespace ion; // Before you run this script please ensure you `export GENICAM_FILENAME=` or`set GENICAM_FILENAME=` // original arv-fake-camera.xml can be download at https://github.com/Sensing-Dev/aravis/blob/main/src/arv-fake-camera.xml // you can also create your fake-camera.xml by editing original xml file and `export GENICAM_FILENAME= -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { try { // Define builders to build, compile, and execute pipelines. // Build the pipeline by adding nodes to the Builder. @@ -28,67 +27,63 @@ int main(int argc, char *argv[]) int num_device = 2; // if you don't set width and height, default width is 640 and default height is 480 Node n = b.add("image_io_u3v_cameraN_u8x2")().set_params( - Param("num_devices", num_device), - Param("pixel_format", "Mono8")); - -/******************** force simulation mode*************************/ -// int width = 960; -// int height = 640; -// int num_device = 2; -// Node n = b.add("image_io_u3v_cameraN_u8x2")().set_params( -// Param("num_devices", num_device), -// Param("force_sim_mode", true), -// Param("width", width), -// Param("height", height)); - -/********************RGB 8*************************/ -// Node n = b.add("image_io_u3v_cameraN_u8x3")().set_params( -// Param("num_devices", num_device), -// Param("pixel_format", "RGB8")); - -/********************Mono16*************************/ -// Node n = b.add("image_io_u3v_cameraN_u16x2")().set_params( -// Param("num_devices", num_device), -// Param("pixel_format", "Mono16")); - - - - std::vector< int > buf_size = std::vector < int >{ width, height}; + Param("num_devices", num_device), + Param("pixel_format", "Mono8")); + + /******************** force simulation mode*************************/ + // int width = 960; + // int height = 640; + // int num_device = 2; + // Node n = b.add("image_io_u3v_cameraN_u8x2")().set_params( + // Param("num_devices", num_device), + // Param("force_sim_mode", true), + // Param("width", width), + // Param("height", height)); + + /********************RGB 8*************************/ + // Node n = b.add("image_io_u3v_cameraN_u8x3")().set_params( + // Param("num_devices", num_device), + // Param("pixel_format", "RGB8")); + + /********************Mono16*************************/ + // Node n = b.add("image_io_u3v_cameraN_u16x2")().set_params( + // Param("num_devices", num_device), + // Param("pixel_format", "Mono16")); + + std::vector buf_size = std::vector{width, height}; std::vector> outputs; std::vector> frame_counts; - for (int i = 0; i < num_device; ++i){ - outputs.push_back(Halide::Buffer(buf_size)); - frame_counts.push_back(Halide::Buffer(1)); + for (int i = 0; i < num_device; ++i) { + outputs.push_back(Halide::Buffer(buf_size)); + frame_counts.push_back(Halide::Buffer(1)); } n["output"].bind(outputs); n["frame_count"].bind(frame_counts); - // Obtain image data int user_input = -1; - while(user_input == -1) - { + while (user_input == -1) { // JIT compilation and execution of pipelines with Builder. - b.run(); - - // Convert the retrieved buffer object to OpenCV buffer format. - // Depends on sensor image pixel format, apply bit shift on images - // Display the image - for (int i = 0;i using Buffer = Halide::Buffer; -} // ion +} // namespace ion -#endif // ION_BUFFER_H +#endif // ION_BUFFER_H diff --git a/include/ion/builder.h b/include/ion/builder.h index 94c6bd94..2199be6f 100644 --- a/include/ion/builder.h +++ b/include/ion/builder.h @@ -24,7 +24,6 @@ class DynamicModule; */ class Builder { public: - struct Impl; /** @@ -41,33 +40,33 @@ class Builder { * Adding new node to the builder. * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). */ - Node add(const std::string& name); + Node add(const std::string &name); /** * Adding new node to the specific graph. * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). * @arg id: graph unique identifier */ - Node add(const std::string& name, const GraphID& graph_id); + Node add(const std::string &name, const GraphID &graph_id); /** * Adding new node to the graph. * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). */ - Graph add_graph(const std::string& name); + Graph add_graph(const std::string &name); /** * Set the target of the pipeline built with this builder. * @arg target: The target ofject which consists of OS, Architecture, and sets of Features. * See https://halide-lang.org/docs/struct_halide_1_1_target.html for more details. */ - Builder& set_target(const Target& target); + Builder &set_target(const Target &target); /** * Set the user context which will be applied the pipeline built with this builder. * @arg user_context_ptr: The pointer to the user context. */ - Builder& set_jit_context(Halide::JITUserContext *user_context_ptr); + Builder &set_jit_context(Halide::JITUserContext *user_context_ptr); /** * Load bb module dynamically and enable it to compile your pipeline. @@ -75,26 +74,26 @@ class Builder { * @note This API is expected to be used from external process. * This information is not stored in graph definition exported by Builder::save because it is not portable. */ - Builder& with_bb_module(const std::string& path); + Builder &with_bb_module(const std::string &path); /** * Save the pipeline as a file in JSON format. * @arg file_name: The file path to be written. */ - void save(const std::string& file_name); + void save(const std::string &file_name); /** * Load the pipeline from a file which is written by Builder::save. * @arg file_name: The file path to be read. */ - void load(const std::string& file_name); + void load(const std::string &file_name); /** * Compile the pipeline into static library and header. * @arg function_name: The symbol name of the entry point in the static library. * This name is also used as prefix of the static library and header. */ - void compile(const std::string& function_name, const CompileOption& option = CompileOption{}); + void compile(const std::string &function_name, const CompileOption &option = CompileOption{}); /** * Run the pipeline immediately. @@ -111,14 +110,13 @@ class Builder { /** * Retrieve arginfo of specific bb */ - std::vector bb_arginfos(const std::string& name); + std::vector bb_arginfos(const std::string &name); /** * Retrieve metadata of Building Block in json format. */ std::string bb_metadata(void); - /** * Get target */ @@ -127,14 +125,13 @@ class Builder { /** * Get the node list. */ - const std::vector& nodes() const; - std::vector& nodes(); + const std::vector &nodes() const; + std::vector &nodes(); /** * Get registered externs */ - const std::map& jit_externs() const; - + const std::map &jit_externs() const; /** Write out the loop nests specified by the schedule for this * Builder's pipeline. Helpful for understanding what a schedule is @@ -145,7 +142,7 @@ class Builder { * Register disposer hook which will be called from Builder destructor. * This is available only for JIT mode. */ - static void register_disposer(Impl* impl, const std::string& bb_id, const std::string& disposer_symbol); + static void register_disposer(Impl *impl, const std::string &bb_id, const std::string &disposer_symbol); /** * Retrieve impl pointer for lowering @@ -153,10 +150,9 @@ class Builder { const Impl *impl_ptr() const; private: - std::shared_ptr impl_; }; -} // namespace ion +} // namespace ion -#endif // ION_BUILDER_H +#endif // ION_BUILDER_H diff --git a/include/ion/building_block.h b/include/ion/building_block.h index 22a55883..4516bb57 100644 --- a/include/ion/building_block.h +++ b/include/ion/building_block.h @@ -34,26 +34,25 @@ class BuildingBlock : public Halide::Generator { BuildingBlockParam builder_impl_ptr{"builder_impl_ptr", 0}; BuildingBlockParam bb_id{"bb_id", ""}; - protected: - - template - void register_disposer(const std::string& n) { - auto builder_impl(reinterpret_cast(static_cast(builder_impl_ptr))); - if (builder_impl) { - Builder::register_disposer(builder_impl, bb_id, n); - } - } - - ion::Buffer get_id() { - std::string bb_id_s(bb_id); - Buffer buf(static_cast(bb_id_s.size() + 1)); - buf.fill(0); - std::memcpy(buf.data(), bb_id_s.c_str(), bb_id_s.size()); - return buf; - } +protected: + template + void register_disposer(const std::string &n) { + auto builder_impl(reinterpret_cast(static_cast(builder_impl_ptr))); + if (builder_impl) { + Builder::register_disposer(builder_impl, bb_id, n); + } + } + + ion::Buffer get_id() { + std::string bb_id_s(bb_id); + Buffer buf(static_cast(bb_id_s.size() + 1)); + buf.fill(0); + std::memcpy(buf.data(), bb_id_s.c_str(), bb_id_s.size()); + return buf; + } }; -} // namespace ion +} // namespace ion #define ION_REGISTER_BUILDING_BLOCK(...) HALIDE_REGISTER_GENERATOR(__VA_ARGS__) diff --git a/include/ion/c_ion.h b/include/ion/c_ion.h index ec59a654..7be18453 100644 --- a/include/ion/c_ion.h +++ b/include/ion/c_ion.h @@ -9,10 +9,10 @@ extern "C" { #endif typedef enum { - ion_type_int = 0, //!< signed integers - ion_type_uint = 1, //!< unsigned integers - ion_type_float = 2, //!< floating point numbers - ion_type_handle = 3 //!< opaque pointer type (void *) + ion_type_int = 0, //!< signed integers + ion_type_uint = 1, //!< unsigned integers + ion_type_float = 2, //!< floating point numbers + ion_type_handle = 3 //!< opaque pointer type (void *) } ion_type_code_t; typedef struct { @@ -34,19 +34,19 @@ typedef struct ion_port_map_t_ *ion_port_map_t; typedef struct ion_graph_t_ *ion_graph_t; int ion_port_create(ion_port_t *, const char *, ion_type_t, int); -int ion_port_create_with_index(ion_port_t *, ion_port_t , int); +int ion_port_create_with_index(ion_port_t *, ion_port_t, int); int ion_port_destroy(ion_port_t); -int ion_port_bind_i8(ion_port_t, int8_t*); -int ion_port_bind_i16(ion_port_t, int16_t*); -int ion_port_bind_i32(ion_port_t, int32_t*); -int ion_port_bind_i64(ion_port_t, int64_t*); -int ion_port_bind_u1(ion_port_t, bool*); -int ion_port_bind_u8(ion_port_t, uint8_t*); -int ion_port_bind_u16(ion_port_t, uint16_t*); -int ion_port_bind_u32(ion_port_t, uint32_t*); -int ion_port_bind_u64(ion_port_t, uint64_t*); -int ion_port_bind_f32(ion_port_t, float*); -int ion_port_bind_f64(ion_port_t, double*); +int ion_port_bind_i8(ion_port_t, int8_t *); +int ion_port_bind_i16(ion_port_t, int16_t *); +int ion_port_bind_i32(ion_port_t, int32_t *); +int ion_port_bind_i64(ion_port_t, int64_t *); +int ion_port_bind_u1(ion_port_t, bool *); +int ion_port_bind_u8(ion_port_t, uint8_t *); +int ion_port_bind_u16(ion_port_t, uint16_t *); +int ion_port_bind_u32(ion_port_t, uint32_t *); +int ion_port_bind_u64(ion_port_t, uint64_t *); +int ion_port_bind_f32(ion_port_t, float *); +int ion_port_bind_f64(ion_port_t, double *); int ion_port_bind_buffer(ion_port_t, ion_buffer_t); int ion_port_bind_buffer_array(ion_port_t, ion_buffer_t *, int); @@ -79,44 +79,29 @@ int ion_buffer_write(ion_buffer_t, void *, int size); int ion_buffer_read(ion_buffer_t, void *, int size); int ion_graph_create(ion_graph_t *, ion_builder_t, const char *); -int ion_graph_add_node(ion_graph_t, const char*, ion_node_t *); +int ion_graph_add_node(ion_graph_t, const char *, ion_node_t *); int ion_graph_destroy(ion_graph_t); int ion_graph_run(ion_graph_t); -int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* objs, int size); +int ion_graph_create_with_multiple(ion_graph_t *ptr, ion_graph_t *objs, int size); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_create(ion_port_map_t *); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_destroy(ion_port_map_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_i8(ion_port_map_t, ion_port_t, int8_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_i16(ion_port_map_t, ion_port_t, int16_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_i32(ion_port_map_t, ion_port_t, int32_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_i64(ion_port_map_t, ion_port_t, int64_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_u1(ion_port_map_t, ion_port_t, bool); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_u8(ion_port_map_t, ion_port_t, uint8_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_u16(ion_port_map_t, ion_port_t, uint16_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_u32(ion_port_map_t, ion_port_t, uint32_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_u64(ion_port_map_t, ion_port_t, uint64_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_f32(ion_port_map_t, ion_port_t, float); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_f64(ion_port_map_t, ion_port_t, double); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_buffer(ion_port_map_t, ion_port_t, ion_buffer_t); -[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] -int ion_port_map_set_buffer_array(ion_port_map_t, ion_port_t, ion_buffer_t *, int); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_create(ion_port_map_t *); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_destroy(ion_port_map_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i8(ion_port_map_t, ion_port_t, int8_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i16(ion_port_map_t, ion_port_t, int16_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i32(ion_port_map_t, ion_port_t, int32_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i64(ion_port_map_t, ion_port_t, int64_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u1(ion_port_map_t, ion_port_t, bool); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u8(ion_port_map_t, ion_port_t, uint8_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u16(ion_port_map_t, ion_port_t, uint16_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u32(ion_port_map_t, ion_port_t, uint32_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u64(ion_port_map_t, ion_port_t, uint64_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_f32(ion_port_map_t, ion_port_t, float); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_f64(ion_port_map_t, ion_port_t, double); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_buffer(ion_port_map_t, ion_port_t, ion_buffer_t); +[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_buffer_array(ion_port_map_t, ion_port_t, ion_buffer_t *, int); #if defined __cplusplus } #endif -#endif // ION_C_ION_H +#endif // ION_C_ION_H diff --git a/include/ion/error.h b/include/ion/error.h index dbf8845b..4d972ece 100644 --- a/include/ion/error.h +++ b/include/ion/error.h @@ -7,6 +7,6 @@ namespace ion { using Error = Halide::Error; -} // ion +} // namespace ion -#endif // ION_ERROR_H +#endif // ION_ERROR_H diff --git a/include/ion/export.h b/include/ion/export.h index 41df0dff..92eb8164 100644 --- a/include/ion/export.h +++ b/include/ion/export.h @@ -7,4 +7,4 @@ #define ION_EXPORT __attribute__((visibility("default"))) #endif -#endif // ION_EXPORT_H +#endif // ION_EXPORT_H diff --git a/include/ion/graph.h b/include/ion/graph.h index bb5aa416..937c9659 100644 --- a/include/ion/graph.h +++ b/include/ion/graph.h @@ -11,22 +11,21 @@ class Builder; class Graph { public: - struct Impl; Graph(); - Graph(Builder & builder , const std::string& name = ""); + Graph(Builder &builder, const std::string &name = ""); - Graph& operator+=(const Graph& rhs); + Graph &operator+=(const Graph &rhs); - friend Graph operator+(const Graph& lhs, const Graph& rhs); + friend Graph operator+(const Graph &lhs, const Graph &rhs); /** * Adding new node to the graph. * @arg n: The name of the building block which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). */ - Node add(const std::string& name); + Node add(const std::string &name); /** * Run the pipeline immediately. @@ -36,13 +35,13 @@ class Graph { * Set the user context which will be applied the pipeline built with this graph. * @arg user_context_ptr: The pointer to the user context. */ - Graph& set_jit_context(Halide::JITUserContext *user_context_ptr); + Graph &set_jit_context(Halide::JITUserContext *user_context_ptr); /** * Get the node list. */ - const std::vector& nodes() const; - std::vector& nodes(); + const std::vector &nodes() const; + std::vector &nodes(); bool defined() const { return impl_.get() != nullptr; @@ -52,6 +51,6 @@ class Graph { std::shared_ptr impl_; }; -} // namespace ion +} // namespace ion -#endif // ION_GRAPH_H +#endif // ION_GRAPH_H diff --git a/include/ion/ion.h b/include/ion/ion.h index 5821b0d3..9f5e3f49 100644 --- a/include/ion/ion.h +++ b/include/ion/ion.h @@ -12,4 +12,4 @@ #include "target.h" #include "type.h" -#endif // ION_ION_H +#endif // ION_ION_H diff --git a/include/ion/node.h b/include/ion/node.h index 8ca8a928..e2e79599 100644 --- a/include/ion/node.h +++ b/include/ion/node.h @@ -28,15 +28,19 @@ class Node { std::vector ports; std::vector arginfos; - Impl(): id(), name(), target(), params(), ports() {} - Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_); - Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID &graph_id_); + Impl() + : id(), name(), target(), params(), ports() { + } + Impl(const NodeID &id_, const std::string &name_, const Halide::Target &target_); + Impl(const NodeID &id_, const std::string &name_, const Halide::Target &target_, const GraphID &graph_id_); }; public: - Node() : impl_(new Impl) {}; + Node() + : impl_(new Impl){}; - Node(const std::shared_ptr& impl) : impl_(impl) {}; + Node(const std::shared_ptr &impl) + : impl_(impl){}; /** * Set the target of the node. @@ -45,7 +49,7 @@ class Node { * This target object can be retrieved by calling BuildingBlock::get_target from BuildingBlock::generate and BuildingBlock::schedule. * @return Node object whose target is set. */ - Node set_target(const Halide::Target& target) { + Node set_target(const Halide::Target &target) { impl_->target = target; return *this; } @@ -57,12 +61,12 @@ class Node { * @return Node object whose parameter is set. */ template - Node set_params(Args ...args) { + Node set_params(Args... args) { impl_->params = std::vector{args...}; return *this; } - void set_params(const std::vector& params) { + void set_params(const std::vector ¶ms) { impl_->params = params; } @@ -76,16 +80,16 @@ class Node { * @return Node object whose port is set. */ template - Node operator()(Args ...args) { + Node operator()(Args... args) { set_iports(std::vector{make_iport(args)...}); return *this; } - void set_iports(const std::vector& ports); + void set_iports(const std::vector &ports); void set_iport(Port port); - void set_iport(const std::string& name, Port port); + void set_iport(const std::string &name, Port port); void set_oport(Port port); @@ -94,49 +98,47 @@ class Node { * @arg name: The name of port name which is matched with first argument of Input/Output declared in user-defined class deriving BuildingBlock. * @return Port object which is specified by name. */ - Port operator[](const std::string& name); + Port operator[](const std::string &name); // Getter - const NodeID & id() const { + const NodeID &id() const { return impl_->id; } - const std::string& name() const { + const std::string &name() const { return impl_->name; } - const Halide::Target& target() const { + const Halide::Target &target() const { return impl_->target; } - const std::vector& params() const { + const std::vector ¶ms() const { return impl_->params; } - const std::vector& ports() const { + const std::vector &ports() const { return impl_->ports; } - Port iport(const std::string& pn); + Port iport(const std::string &pn); std::vector> iports() const; - Port oport(const std::string& pn); + Port oport(const std::string &pn); std::vector> oports() const; std::vector> unbound_iports() const; std::vector> unbound_oports() const; - void detect_data_hazard ()const ; + void detect_data_hazard() const; private: - Node(const NodeID& id, const std::string& name, const Halide::Target& target) - : impl_(new Impl{id, name, target}) - { + Node(const NodeID &id, const std::string &name, const Halide::Target &target) + : impl_(new Impl{id, name, target}) { } - Node(const NodeID&& id, const std::string& name, const Halide::Target& target, const GraphID& graph_id) - : impl_(new Impl{id, name, target, graph_id}) - { + Node(const NodeID &&id, const std::string &name, const Halide::Target &target, const GraphID &graph_id) + : impl_(new Impl{id, name, target, graph_id}) { } Port make_iport(Port arg) const { @@ -152,7 +154,7 @@ class Node { } template - Port make_iport(Halide::Buffer& arg) const { + Port make_iport(Halide::Buffer &arg) const { if (to_string(impl_->graph_id).empty()) return Port(arg); else @@ -160,17 +162,16 @@ class Node { } template - Port make_iport(std::vector>& arg) const { + Port make_iport(std::vector> &arg) const { if (to_string(impl_->graph_id).empty()) return Port(arg); else return Port(arg, impl_->graph_id); } - std::shared_ptr impl_; }; -} // namespace ion +} // namespace ion -#endif // ION_NODE_H +#endif // ION_NODE_H diff --git a/include/ion/param.h b/include/ion/param.h index 0dbad5a0..c1642f68 100644 --- a/include/ion/param.h +++ b/include/ion/param.h @@ -10,37 +10,52 @@ namespace ion { * Param class is used to create static parameter for each node. */ class Param { - public: - Param() {} - - /** - * Create static parameter which is passed as GeneratorParam declared in user-defined class deriving BuildingBlock. - * @arg key: Key of the parameter. - * It should be matched with first argument of GeneratorParam declared in user-defined class deriving BuildingBlock. - * @arg val: Value in string. - * It can be string representation which is able to convert through std::istringstream. - */ - Param(const std::string& key, const std::string& val) : key_(key), val_(val) {} - - template::value>::type* = nullptr> - Param(const std::string& key, T val) : key_(key), val_(val ? "true" : "false") {} - - template::value && !std::is_same::value>::type* = nullptr> - Param(const std::string& key, T val) : key_(key), val_(std::to_string(val)) {} - - std::string key() const { return key_; } - std::string& key() { return key_; } - - std::string val() const { return val_; } - std::string& val() { return val_; } - - private: +public: + Param() { + } + + /** + * Create static parameter which is passed as GeneratorParam declared in user-defined class deriving BuildingBlock. + * @arg key: Key of the parameter. + * It should be matched with first argument of GeneratorParam declared in user-defined class deriving BuildingBlock. + * @arg val: Value in string. + * It can be string representation which is able to convert through std::istringstream. + */ + Param(const std::string &key, const std::string &val) + : key_(key), val_(val) { + } + + template::value>::type * = nullptr> + Param(const std::string &key, T val) + : key_(key), val_(val ? "true" : "false") { + } + + template::value && !std::is_same::value>::type * = nullptr> + Param(const std::string &key, T val) + : key_(key), val_(std::to_string(val)) { + } + + std::string key() const { + return key_; + } + std::string &key() { + return key_; + } + + std::string val() const { + return val_; + } + std::string &val() { + return val_; + } + +private: std::string key_; std::string val_; }; -} // namespace ion +} // namespace ion -#endif // ION_PARAM_H +#endif // ION_PARAM_H diff --git a/include/ion/port.h b/include/ion/port.h index 0c0cf687..7206eeee 100644 --- a/include/ion/port.h +++ b/include/ion/port.h @@ -17,9 +17,9 @@ namespace ion { template -std::string unify_name(const std::vector>& bufs) { +std::string unify_name(const std::vector> &bufs) { std::stringstream ss; - for (auto i=0; i>& bufs) { } template -int32_t unify_dimension(const std::vector>& bufs) { +int32_t unify_dimension(const std::vector> &bufs) { int32_t dimension = 0; - for (auto i=0; i params; std::unordered_map instances; - std::unordered_map > bound_address; + std::unordered_map> bound_address; Impl(); - Impl(const NodeID& nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID &gid ); + Impl(const NodeID &nid, const std::string &pn, const Halide::Type &t, int32_t d, const GraphID &gid); }; public: + Port() + : impl_(new Impl(NodeID(""), "", Halide::Type(), 0, GraphID(""))), index_(-1) { + } - Port() : impl_(new Impl(NodeID(""), "", Halide::Type(), 0, GraphID(""))), index_(-1) {} - - Port(const std::shared_ptr& impl, int32_t index) : impl_(impl), index_(index) {} + Port(const std::shared_ptr &impl, int32_t index) + : impl_(impl), index_(index) { + } /** * Construct new port for scalar value. * @arg k: The key of the port which should be matched with BuildingBlock Input/Output name. * @arg t: The type of the value. */ - Port(const std::string& n, Halide::Type t) : impl_(new Impl(NodeID(""), n, t, 0, GraphID(""))), index_(-1) {} + Port(const std::string &n, Halide::Type t) + : impl_(new Impl(NodeID(""), n, t, 0, GraphID(""))), index_(-1) { + } /** * Construct new port for vector value. @@ -89,32 +94,36 @@ class Port { * @arg t: The type of the element value. * @arg d: The dimension of the port. The range is 1 to 4. */ - Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl(NodeID(""), n, t, d, GraphID(""))), index_(-1) {} + Port(const std::string &n, Halide::Type t, int32_t d) + : impl_(new Impl(NodeID(""), n, t, d, GraphID(""))), index_(-1) { + } /** * Construct new port from scalar pointer */ template::value>::type* = nullptr> - Port(T *vptr) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, GraphID(""))), index_(-1) { + typename std::enable_if::value>::type * = nullptr> + Port(T *vptr) + : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, GraphID(""))), index_(-1) { this->bind(vptr); } - /** + /** * Construct new port from scalar pointer */ template::value>::type* = nullptr> - Port(T *vptr, const GraphID & gid) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, gid)), index_(-1) { + typename std::enable_if::value>::type * = nullptr> + Port(T *vptr, const GraphID &gid) + : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, gid)), index_(-1) { this->bind(vptr); } - /** * Construct new port from buffer */ template - Port(const Halide::Buffer& buf) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), GraphID(""))), index_(-1) { + Port(const Halide::Buffer &buf) + : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), GraphID(""))), index_(-1) { this->bind(buf); } @@ -122,7 +131,8 @@ class Port { * Construct new port from buffer and bind graph id to port */ template - Port(const Halide::Buffer& buf, const GraphID & gid) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), gid)), index_(-1) { + Port(const Halide::Buffer &buf, const GraphID &gid) + : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), gid)), index_(-1) { this->bind(buf); } @@ -130,164 +140,195 @@ class Port { * Construct new port from array of buffer */ template - Port(const std::vector>& bufs) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), GraphID(""))), index_(-1) { + Port(const std::vector> &bufs) + : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), GraphID(""))), index_(-1) { this->bind(bufs); } - /** + /** * Construct new port from array of buffer and bind graph id to port */ template - Port(const std::vector>& bufs, const GraphID & gid) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), gid)), index_(-1) { + Port(const std::vector> &bufs, const GraphID &gid) + : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), gid)), index_(-1) { this->bind(bufs); } // Getter - const PortID id() const { return impl_->id; } - const Channel& pred_chan() const { return impl_->pred_chan; } - const NodeID& pred_id() const { return std::get<0>(impl_->pred_chan); } - const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); } - const std::set& succ_chans() const { return impl_->succ_chans; } - const Halide::Type& type() const { return impl_->type; } - int32_t dimensions() const { return impl_->dimensions; } - int32_t size() const { return static_cast(impl_->params.size()); } - int32_t index() const { return index_; } - const GraphID& graph_id() const { return impl_->graph_id; } + const PortID id() const { + return impl_->id; + } + const Channel &pred_chan() const { + return impl_->pred_chan; + } + const NodeID &pred_id() const { + return std::get<0>(impl_->pred_chan); + } + const std::string &pred_name() const { + return std::get<1>(impl_->pred_chan); + } + const std::set &succ_chans() const { + return impl_->succ_chans; + } + const Halide::Type &type() const { + return impl_->type; + } + int32_t dimensions() const { + return impl_->dimensions; + } + int32_t size() const { + return static_cast(impl_->params.size()); + } + int32_t index() const { + return index_; + } + const GraphID &graph_id() const { + return impl_->graph_id; + } // Setter - void set_index(int index) { index_ = index; } + void set_index(int index) { + index_ = index; + } // Util - bool has_pred() const { return !std::get<0>(impl_->pred_chan).value().empty(); } - bool has_pred_by_nid(const NodeID & nid) const { return !to_string(std::get<0>(impl_->pred_chan)).empty(); } - bool has_succ() const { return !impl_->succ_chans.empty(); } - bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); } - bool has_succ_by_nid(const NodeID& nid) const { + bool has_pred() const { + return !std::get<0>(impl_->pred_chan).value().empty(); + } + bool has_pred_by_nid(const NodeID &nid) const { + return !to_string(std::get<0>(impl_->pred_chan)).empty(); + } + bool has_succ() const { + return !impl_->succ_chans.empty(); + } + bool has_succ(const Channel &c) const { + return impl_->succ_chans.count(c); + } + bool has_succ_by_nid(const NodeID &nid) const { return std::count_if(impl_->succ_chans.begin(), impl_->succ_chans.end(), - [&](const Port::Channel& c) { return std::get<0>(c) == nid; }); + [&](const Port::Channel &c) { return std::get<0>(c) == nid; }); } - void determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn); + void determine_succ(const NodeID &nid, const std::string &old_pn, const std::string &new_pn); /** * Overloaded operator to set the port index and return a reference to the current port. eg. port[0] */ - Port operator[](int index) { - Port port(*this); - port.index_ = index; - return port; - } - - template - void bind(T *v) { - bind(Halide::type_of(), v); - } - - void bind(Halide::Type target_type, void *v) { - auto i = index_ == -1 ? 0 : index_; - - if (has_pred()) { - impl_->params[i] = Halide::Parameter{target_type, false, 0, argument_name(pred_id(), id(), pred_name(), i, graph_id())}; - } else { - impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; - } - - impl_->instances[i] = v; - impl_->bound_address[i] = std::make_tuple(v, false); - } - - template - void bind(const Halide::Buffer& buf) { - auto i = index_ == -1 ? 0 : index_; - if (has_pred()) { - impl_->params[i] = Halide::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())}; - } else { - impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())}; - } - - auto raw_buf = buf.raw_buffer(); - impl_->instances[i] = raw_buf; - impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast(raw_buf->host) : reinterpret_cast(raw_buf->device), false); - } - - template - void bind(const std::vector>& bufs) { - for (int i=0; i(bufs.size()); ++i) { - if (has_pred()) { - impl_->params[i] = Halide::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), id(),pred_name(), i, graph_id())}; - } else { - impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; - } - - auto raw_buf = bufs[i].raw_buffer(); - impl_->instances[i] = raw_buf; - impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast(raw_buf->host) : reinterpret_cast(raw_buf->device), false); - } - - } - - static std::tuple, bool> find_impl(const std::string& id); - - std::vector as_expr() const { - if (dimensions() != 0) { + Port operator[](int index) { + Port port(*this); + port.index_ = index; + return port; + } + + template + void bind(T *v) { + bind(Halide::type_of(), v); + } + + void bind(Halide::Type target_type, void *v) { + auto i = index_ == -1 ? 0 : index_; + + if (has_pred()) { + impl_->params[i] = Halide::Parameter{target_type, false, 0, argument_name(pred_id(), id(), pred_name(), i, graph_id())}; + } else { + impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; + } + + impl_->instances[i] = v; + impl_->bound_address[i] = std::make_tuple(v, false); + } + + template + void bind(const Halide::Buffer &buf) { + auto i = index_ == -1 ? 0 : index_; + if (has_pred()) { + impl_->params[i] = Halide::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; + } else { + impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; + } + + auto raw_buf = buf.raw_buffer(); + impl_->instances[i] = raw_buf; + impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast(raw_buf->host) : reinterpret_cast(raw_buf->device), false); + } + + template + void bind(const std::vector> &bufs) { + for (int i = 0; i < static_cast(bufs.size()); ++i) { + if (has_pred()) { + impl_->params[i] = Halide::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; + } else { + impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; + } + + auto raw_buf = bufs[i].raw_buffer(); + impl_->instances[i] = raw_buf; + impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast(raw_buf->host) : reinterpret_cast(raw_buf->device), false); + } + } + + static std::tuple, bool> find_impl(const std::string &id); + + std::vector as_expr() const { + if (dimensions() != 0) { throw std::runtime_error("Unreachable"); - } - - std::vector es; - for (const auto& [i, param] : impl_->params) { - if (es.size() <= i) { - es.resize(i+1, Halide::Expr()); - } - es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), id(), pred_name(), i, graph_id()), param); - } - return es; - } - - std::vector as_func() const { - std::vector fs; - for (const auto& [i, param] : impl_->params ) { - if (fs.size() <= i) { - fs.resize(i+1, Halide::Func()); - } - std::vector args; - std::vector args_expr; - for (int j = 0; j < dimensions(); ++j) { - args.push_back(Halide::Var::implicit(j)); - args_expr.push_back(Halide::Var::implicit(j)); - } - Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im"); - f(args) = Halide::Internal::Call::make(param, args_expr); - fs[i] = f; - if (std::get<1>(impl_->bound_address[i])) { - f.compute_root(); - } - } - return fs; - } - - std::vector as_argument() const { - std::vector args; - for (const auto& [i, param] : impl_->params) { - if (args.size() <= i) { - args.resize(i+1, Halide::Argument()); - } - auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer; - args[i] = Halide::Argument(argument_name(pred_id(), id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates()); - } - return args; - } - - std::vector as_instance() const { - std::vector instances; - for (const auto& [i, instance] : impl_->instances) { - if (instances.size() <= i) { - instances.resize(i+1, nullptr); - } - instances[i] = instance; } - return instances; - } + + std::vector es; + for (const auto &[i, param] : impl_->params) { + if (es.size() <= i) { + es.resize(i + 1, Halide::Expr()); + } + es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), id(), pred_name(), i, graph_id()), param); + } + return es; + } + + std::vector as_func() const { + std::vector fs; + for (const auto &[i, param] : impl_->params) { + if (fs.size() <= i) { + fs.resize(i + 1, Halide::Func()); + } + std::vector args; + std::vector args_expr; + for (int j = 0; j < dimensions(); ++j) { + args.push_back(Halide::Var::implicit(j)); + args_expr.push_back(Halide::Var::implicit(j)); + } + Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im"); + f(args) = Halide::Internal::Call::make(param, args_expr); + fs[i] = f; + if (std::get<1>(impl_->bound_address[i])) { + f.compute_root(); + } + } + return fs; + } + + std::vector as_argument() const { + std::vector args; + for (const auto &[i, param] : impl_->params) { + if (args.size() <= i) { + args.resize(i + 1, Halide::Argument()); + } + auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer; + args[i] = Halide::Argument(argument_name(pred_id(), id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates()); + } + return args; + } + + std::vector as_instance() const { + std::vector instances; + for (const auto &[i, instance] : impl_->instances) { + if (instances.size() <= i) { + instances.resize(i + 1, nullptr); + } + instances[i] = instance; + } + return instances; + } private: /** @@ -296,16 +337,18 @@ class Port { * pid and pn is stored in both pred and succ, * then it will determined through pipeline build process. */ - Port(const NodeID & nid, const std::string& pn) : impl_(new Impl(nid, pn, Halide::Type(), 0, GraphID(""))), index_(-1) {} + Port(const NodeID &nid, const std::string &pn) + : impl_(new Impl(nid, pn, Halide::Type(), 0, GraphID(""))), index_(-1) { + } - std::shared_ptr impl_; + std::shared_ptr impl_; - // NOTE: - // The reasons why index sits outside of the impl_ is because - // index is tentatively used to hold index of params. - int32_t index_; + // NOTE: + // The reasons why index sits outside of the impl_ is because + // index is tentatively used to hold index of params. + int32_t index_; }; -} // namespace ion +} // namespace ion -#endif // ION_PORT_H +#endif // ION_PORT_H diff --git a/include/ion/target.h b/include/ion/target.h index e97d2138..760e28c0 100644 --- a/include/ion/target.h +++ b/include/ion/target.h @@ -11,6 +11,6 @@ Target get_host_target(); Target get_target_from_environment(); -} // ion +} // namespace ion -#endif // ION_TARGET_H +#endif // ION_TARGET_H diff --git a/include/ion/type.h b/include/ion/type.h index ff0d6ed0..ba102e45 100644 --- a/include/ion/type.h +++ b/include/ion/type.h @@ -12,6 +12,6 @@ Type type_of() { return Halide::type_of(); } -} // ion +} // namespace ion -#endif // ION_TYPE_H +#endif // ION_TYPE_H diff --git a/include/ion/util.h b/include/ion/util.h index d9956037..664a4d9e 100644 --- a/include/ion/util.h +++ b/include/ion/util.h @@ -6,7 +6,7 @@ namespace ion { class Port; -std::string array_name(const std::string& port_name, size_t i); +std::string array_name(const std::string &port_name, size_t i); // a string-like identifier that is typed on a tag type template @@ -14,15 +14,20 @@ struct StringID { using tag_type = Tag; // needs to be default-constructable because of use in map[] below - StringID(std::string s) : _value(std::move(s)) {} - StringID() : _value() {} + StringID(std::string s) + : _value(std::move(s)) { + } + StringID() + : _value() { + } // provide access to the underlying string value - const std::string &value() const { return _value; } + const std::string &value() const { + return _value; + } - struct StringIDHash { - // Use hash of string as hash function. - size_t operator()(const StringID& id) const - { + struct StringIDHash { + // Use hash of string as hash function. + size_t operator()(const StringID &id) const { return std::hash()(id.value()); } }; @@ -40,9 +45,8 @@ struct StringID { } // and let's go ahead and provide expected free functions - friend - auto to_string(const StringID &r) - -> const std::string & { + friend auto to_string(const StringID &r) + -> const std::string & { return r._value; } }; @@ -55,8 +59,8 @@ using NodeID = StringID; using GraphID = StringID; using PortID = StringID; -std::string argument_name(const NodeID& node_id, const PortID & portId, const std::string& name, int32_t index, const GraphID& graph_id); +std::string argument_name(const NodeID &node_id, const PortID &portId, const std::string &name, int32_t index, const GraphID &graph_id); -} // namespace ion +} // namespace ion #endif diff --git a/src/bb/base/bb.h b/src/bb/base/bb.h index 446625fb..618d95db 100644 --- a/src/bb/base/bb.h +++ b/src/bb/base/bb.h @@ -1963,8 +1963,8 @@ class ScalarToFuncFloat : public ScalarToFunc { class Schedule : public ion::BuildingBlock { public: BuildingBlockParam output_name{"output_name", ""}; - BuildingBlockParam compute_level{"compute_level", ""}; // "compute_inline" or "compute_root" - BuildingBlockParam concurrency{"concurrency", ""}; // comma separated string + BuildingBlockParam compute_level{"compute_level", ""}; // "compute_inline" or "compute_root" + BuildingBlockParam concurrency{"concurrency", ""}; // comma separated string Input input{"input"}; Output output{"output"}; @@ -1988,7 +1988,7 @@ class Schedule : public ion::BuildingBlock { } else { Var x = f.args()[0]; Var y = f.args()[1]; - for (int i=2; i { } else { Var x = f.args()[0]; Var y = f.args()[1]; - for (int i=2; i 0 \ - ? std::thread::hardware_concurrency() - 1 \ - : 0)) +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() - 1 : 0)) #endif /* @@ -94,11 +92,11 @@ #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS -#endif //_CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS #ifndef _CRT_NONSTDC_NO_DEPRECATE #define _CRT_NONSTDC_NO_DEPRECATE -#endif //_CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE #if defined(_MSC_VER) #ifdef _WIN64 @@ -110,19 +108,19 @@ using ssize_t = int; #if _MSC_VER < 1900 #define snprintf _snprintf_s #endif -#endif // _MSC_VER +#endif // _MSC_VER #ifndef S_ISREG #define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) -#endif // S_ISREG +#endif // S_ISREG #ifndef S_ISDIR #define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) -#endif // S_ISDIR +#endif // S_ISDIR #ifndef NOMINMAX #define NOMINMAX -#endif // NOMINMAX +#endif // NOMINMAX #include #include @@ -142,14 +140,14 @@ using ssize_t = int; #ifndef strcasecmp #define strcasecmp _stricmp -#endif // strcasecmp +#endif // strcasecmp using socket_t = SOCKET; #ifdef CPPHTTPLIB_USE_POLL #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif -#else // not _WIN32 +#else // not _WIN32 #include #include @@ -171,7 +169,7 @@ using socket_t = SOCKET; using socket_t = int; #define INVALID_SOCKET (-1) -#endif //_WIN32 +#endif //_WIN32 #include #include @@ -216,7 +214,7 @@ using socket_t = int; #if OPENSSL_VERSION_NUMBER < 0x10100000L #include inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { - return M_ASN1_STRING_data(asn1); + return M_ASN1_STRING_data(asn1); } #endif #endif @@ -238,14 +236,14 @@ namespace httplib { namespace detail { struct ci { - bool operator()(const std::string &s1, const std::string &s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), s2.begin(), s2.end(), - [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); - } + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } }; -} // namespace detail +} // namespace detail using Headers = std::multimap; @@ -258,44 +256,48 @@ struct Response; using ResponseHandler = std::function; struct MultipartFormData { - std::string name; - std::string content; - std::string filename; - std::string content_type; + std::string name; + std::string content; + std::string filename; + std::string content_type; }; using MultipartFormDataItems = std::vector; using MultipartFormDataMap = std::multimap; class DataSink { public: - DataSink() : os(&sb_), sb_(*this) {} + DataSink() + : os(&sb_), sb_(*this) { + } - DataSink(const DataSink &) = delete; - DataSink &operator=(const DataSink &) = delete; - DataSink(DataSink &&) = delete; - DataSink &operator=(DataSink &&) = delete; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; - std::function write; - std::function done; - std::function is_writable; - std::ostream os; + std::function write; + std::function done; + std::function is_writable; + std::ostream os; private: - class data_sink_streambuf : public std::streambuf { - public: - explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + class data_sink_streambuf : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) + : sink_(sink) { + } - protected: - std::streamsize xsputn(const char *s, std::streamsize n) { - sink_.write(s, static_cast(n)); - return n; - } + protected: + std::streamsize xsputn(const char *s, std::streamsize n) { + sink_.write(s, static_cast(n)); + return n; + } - private: - DataSink &sink_; - }; + private: + DataSink &sink_; + }; - data_sink_streambuf sb_; + data_sink_streambuf sb_; }; using ContentProvider = @@ -312,223 +314,231 @@ using MultipartContentHeader = class ContentReader { public: - using Reader = std::function; - using MultipartReader = std::function; + using Reader = std::function; + using MultipartReader = std::function; - ContentReader(Reader reader, MultipartReader multipart_reader) - : reader_(reader), multipart_reader_(multipart_reader) {} + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(reader), multipart_reader_(multipart_reader) { + } - bool operator()(MultipartContentHeader header, - ContentReceiver receiver) const { - return multipart_reader_(header, receiver); - } + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return multipart_reader_(header, receiver); + } - bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + bool operator()(ContentReceiver receiver) const { + return reader_(receiver); + } - Reader reader_; - MultipartReader multipart_reader_; + Reader reader_; + MultipartReader multipart_reader_; }; using Range = std::pair; using Ranges = std::vector; struct Request { - std::string method; - std::string path; - Headers headers; - std::string body; - - std::string remote_addr; - int remote_port = -1; - - // for server - std::string version; - std::string target; - Params params; - MultipartFormDataMap files; - Ranges ranges; - Match matches; - - // for client - size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; - ResponseHandler response_handler; - ContentReceiver content_receiver; - size_t content_length = 0; - ContentProvider content_provider; - Progress progress; + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + size_t content_length = 0; + ContentProvider content_provider; + Progress progress; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl; + const SSL *ssl; #endif - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - template - T get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - bool has_param(const char *key) const; - std::string get_param_value(const char *key, size_t id = 0) const; - size_t get_param_value_count(const char *key) const; + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; - bool is_multipart_form_data() const; + bool is_multipart_form_data() const; - bool has_file(const char *key) const; - MultipartFormData get_file_value(const char *key) const; + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; - // private members... - size_t authorization_count_ = 0; + // private members... + size_t authorization_count_ = 0; }; struct Response { - std::string version; - int status = -1; - std::string reason; - Headers headers; - std::string body; - - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - template - T get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); - - void set_redirect(const char *url, int status = 302); - void set_redirect(const std::string &url, int status = 302); - void set_content(const char *s, size_t n, const char *content_type); - void set_content(std::string s, const char *content_type); - - void set_content_provider( - size_t length, const char *content_type, ContentProvider provider, - const std::function &resource_releaser = nullptr); - - void set_content_provider( - const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); - - void set_chunked_content_provider( - const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); - - Response() = default; - Response(const Response &) = default; - Response &operator=(const Response &) = default; - Response(Response &&) = default; - Response &operator=(Response &&) = default; - ~Response() { - if (content_provider_resource_releaser_) { - content_provider_resource_releaser_(); - } - } - - // private members... - size_t content_length_ = 0; - ContentProvider content_provider_; - std::function content_provider_resource_releaser_; - bool is_chunked_content_provider = false; + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url, int status = 302); + void set_redirect(const std::string &url, int status = 302); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(std::string s, const char *content_type); + + void set_content_provider( + size_t length, const char *content_type, ContentProvider provider, + const std::function &resource_releaser = nullptr); + + void set_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + const std::function &resource_releaser = nullptr); + + void set_chunked_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + const std::function &resource_releaser = nullptr); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + std::function content_provider_resource_releaser_; + bool is_chunked_content_provider = false; }; class Stream { public: - virtual ~Stream() = default; + virtual ~Stream() = default; - virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; - virtual ssize_t read(char *ptr, size_t size) = 0; - virtual ssize_t write(const char *ptr, size_t size) = 0; - virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; - template - ssize_t write_format(const char *fmt, const Args &... args); - ssize_t write(const char *ptr); - ssize_t write(const std::string &s); + template + ssize_t write_format(const char *fmt, const Args &...args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); }; class TaskQueue { public: - TaskQueue() = default; - virtual ~TaskQueue() = default; + TaskQueue() = default; + virtual ~TaskQueue() = default; - virtual void enqueue(std::function fn) = 0; - virtual void shutdown() = 0; + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; - virtual void on_idle(){}; + virtual void on_idle(){}; }; class ThreadPool : public TaskQueue { public: - explicit ThreadPool(size_t n) : shutdown_(false) { - while (n) { - threads_.emplace_back(worker(*this)); - n--; + explicit ThreadPool(size_t n) + : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } } - } - - ThreadPool(const ThreadPool &) = delete; - ~ThreadPool() override = default; - void enqueue(std::function fn) override { - std::unique_lock lock(mutex_); - jobs_.push_back(fn); - cond_.notify_one(); - } + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; - void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); } - cond_.notify_all(); + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); - // Join... - for (auto &t : threads_) { - t.join(); + // Join... + for (auto &t : threads_) { + t.join(); + } } - } private: - struct worker { - explicit worker(ThreadPool &pool) : pool_(pool) {} + struct worker { + explicit worker(ThreadPool &pool) + : pool_(pool) { + } - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + if (pool_.shutdown_ && pool_.jobs_.empty()) { + break; + } - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } - assert(true == static_cast(fn)); - fn(); - } - } + assert(true == static_cast(fn)); + fn(); + } + } - ThreadPool &pool_; - }; - friend struct worker; + ThreadPool &pool_; + }; + friend struct worker; - std::vector threads_; - std::list> jobs_; + std::vector threads_; + std::list> jobs_; - bool shutdown_; + bool shutdown_; - std::condition_variable cond_; - std::mutex mutex_; + std::condition_variable cond_; + std::mutex mutex_; }; using Logger = std::function; @@ -536,670 +546,687 @@ using Logger = std::function; using SocketOptions = std::function; inline void default_socket_options(socket_t sock) { - int yes = 1; + int yes = 1; #ifdef _WIN32 - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); - setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast(&yes), sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&yes), sizeof(yes)); #else #ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), - sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); #else - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); #endif #endif } class Server { public: - using Handler = std::function; - using HandlerWithContentReader = std::function; - using Expect100ContinueHandler = - std::function; + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Expect100ContinueHandler = + std::function; - Server(); + Server(); - virtual ~Server(); + virtual ~Server(); - virtual bool is_valid() const; + virtual bool is_valid() const; - Server &Get(const char *pattern, Handler handler); - Server &Post(const char *pattern, Handler handler); - Server &Post(const char *pattern, HandlerWithContentReader handler); - Server &Put(const char *pattern, Handler handler); - Server &Put(const char *pattern, HandlerWithContentReader handler); - Server &Patch(const char *pattern, Handler handler); - Server &Patch(const char *pattern, HandlerWithContentReader handler); - Server &Delete(const char *pattern, Handler handler); - Server &Delete(const char *pattern, HandlerWithContentReader handler); - Server &Options(const char *pattern, Handler handler); + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Delete(const char *pattern, HandlerWithContentReader handler); + Server &Options(const char *pattern, Handler handler); - bool set_base_dir(const char *dir, const char *mount_point = nullptr); - bool set_mount_point(const char *mount_point, const char *dir); - bool remove_mount_point(const char *mount_point); - void set_file_extension_and_mimetype_mapping(const char *ext, - const char *mime); - void set_file_request_handler(Handler handler); + bool set_base_dir(const char *dir, const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); - void set_error_handler(Handler handler); - void set_expect_100_continue_handler(Expect100ContinueHandler handler); - void set_logger(Logger logger); + void set_error_handler(Handler handler); + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + void set_logger(Logger logger); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_keep_alive_max_count(size_t count); - void set_keep_alive_timeout(time_t sec); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); - void set_idle_interval(time_t sec, time_t usec = 0); + void set_keep_alive_max_count(size_t count); + void set_keep_alive_timeout(time_t sec); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); + void set_idle_interval(time_t sec, time_t usec = 0); - void set_payload_max_length(size_t length); + void set_payload_max_length(size_t length); - bool bind_to_port(const char *host, int port, int socket_flags = 0); - int bind_to_any_port(const char *host, int socket_flags = 0); - bool listen_after_bind(); + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - bool listen(const char *host, int port, int socket_flags = 0); + bool listen(const char *host, int port, int socket_flags = 0); - bool is_running() const; - void stop(); + bool is_running() const; + void stop(); - std::function new_task_queue; + std::function new_task_queue; protected: - bool process_request(Stream &strm, bool close_connection, - bool &connection_closed, - const std::function &setup_request); - - std::atomic svr_sock_; - size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; - time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; - time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; - size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + bool process_request(Stream &strm, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; private: - using Handlers = std::vector>; - using HandlersForContentReader = - std::vector>; - - socket_t create_server_socket(const char *host, int port, int socket_flags, - SocketOptions socket_options) const; - int bind_internal(const char *host, int port, int socket_flags); - bool listen_internal(); - - bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(Request &req, Response &res, bool head = false); - bool dispatch_request(Request &req, Response &res, const Handlers &handlers); - bool - dispatch_request_for_content_reader(Request &req, Response &res, - ContentReader content_reader, - const HandlersForContentReader &handlers); - - bool parse_request_line(const char *s, Request &req); - bool write_response(Stream &strm, bool close_connection, const Request &req, - Response &res); - bool write_content_with_provider(Stream &strm, const Request &req, - Response &res, const std::string &boundary, - const std::string &content_type); - bool read_content(Stream &strm, Request &req, Response &res); - bool - read_content_with_content_receiver(Stream &strm, Request &req, Response &res, - ContentReceiver receiver, - MultipartContentHeader multipart_header, - ContentReceiver multipart_receiver); - bool read_content_core(Stream &strm, Request &req, Response &res, - ContentReceiver receiver, - MultipartContentHeader mulitpart_header, - ContentReceiver multipart_receiver); - - virtual bool process_and_close_socket(socket_t sock); - - std::atomic is_running_; - std::vector> base_dirs_; - std::map file_extension_and_mimetype_map_; - Handler file_request_handler_; - Handlers get_handlers_; - Handlers post_handlers_; - HandlersForContentReader post_handlers_for_content_reader_; - Handlers put_handlers_; - HandlersForContentReader put_handlers_for_content_reader_; - Handlers patch_handlers_; - HandlersForContentReader patch_handlers_for_content_reader_; - Handlers delete_handlers_; - HandlersForContentReader delete_handlers_for_content_reader_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; - Expect100ContinueHandler expect_100_continue_handler_; - - bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; - SocketOptions socket_options_ = default_socket_options; + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; + + socket_t create_server_socket(const char *host, int port, int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, const Handlers &handlers); + bool + dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + const HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool close_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; + + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = default_socket_options; }; enum Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification }; class Result { public: - Result(const std::shared_ptr &res, Error err) - : res_(res), err_(err) {} - operator bool() const { return res_ != nullptr; } - bool operator==(std::nullptr_t) const { return res_ == nullptr; } - bool operator!=(std::nullptr_t) const { return res_ != nullptr; } - const Response &value() const { return *res_; } - const Response &operator*() const { return *res_; } - const Response *operator->() const { return res_.get(); } - Error error() const { return err_; } + Result(const std::shared_ptr &res, Error err) + : res_(res), err_(err) { + } + operator bool() const { + return res_ != nullptr; + } + bool operator==(std::nullptr_t) const { + return res_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return res_ != nullptr; + } + const Response &value() const { + return *res_; + } + const Response &operator*() const { + return *res_; + } + const Response *operator->() const { + return res_.get(); + } + Error error() const { + return err_; + } private: - std::shared_ptr res_; - Error err_; + std::shared_ptr res_; + Error err_; }; class ClientImpl { public: - explicit ClientImpl(const std::string &host); - - explicit ClientImpl(const std::string &host, int port); - - explicit ClientImpl(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); - - virtual ~ClientImpl(); - - virtual bool is_valid() const; - - Result Get(const char *path); - Result Get(const char *path, const Headers &headers); - Result Get(const char *path, Progress progress); - Result Get(const char *path, const Headers &headers, Progress progress); - Result Get(const char *path, ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); - Result Get(const char *path, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, ContentReceiver content_receiver, - Progress progress); - - Result Head(const char *path); - Result Head(const char *path, const Headers &headers); - - Result Post(const char *path); - Result Post(const char *path, const std::string &body, - const char *content_type); - Result Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Post(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Params ¶ms); - Result Post(const char *path, const Headers &headers, const Params ¶ms); - Result Post(const char *path, const MultipartFormDataItems &items); - Result Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - - Result Put(const char *path); - Result Put(const char *path, const std::string &body, - const char *content_type); - Result Put(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Put(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Params ¶ms); - Result Put(const char *path, const Headers &headers, const Params ¶ms); - - Result Patch(const char *path, const std::string &body, + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const char *path); + Result Get(const char *path, const Headers &headers); + Result Get(const char *path, Progress progress); + Result Get(const char *path, const Headers &headers, Progress progress); + Result Get(const char *path, ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const char *path, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Head(const char *path); + Result Head(const char *path, const Headers &headers); + + Result Post(const char *path); + Result Post(const char *path, const std::string &body, + const char *content_type); + Result Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Post(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Params ¶ms); + Result Post(const char *path, const Headers &headers, const Params ¶ms); + Result Post(const char *path, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + Result Put(const char *path); + Result Put(const char *path, const std::string &body, + const char *content_type); + Result Put(const char *path, const Headers &headers, const std::string &body, const char *content_type); - Result Patch(const char *path, const Headers &headers, - const std::string &body, const char *content_type); - Result Patch(const char *path, size_t content_length, + Result Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type); - Result Patch(const char *path, const Headers &headers, size_t content_length, + Result Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type); + Result Put(const char *path, const Params ¶ms); + Result Put(const char *path, const Headers &headers, const Params ¶ms); - Result Delete(const char *path); - Result Delete(const char *path, const std::string &body, - const char *content_type); - Result Delete(const char *path, const Headers &headers); - Result Delete(const char *path, const Headers &headers, - const std::string &body, const char *content_type); + Result Patch(const char *path, const std::string &body, + const char *content_type); + Result Patch(const char *path, const Headers &headers, + const std::string &body, const char *content_type); + Result Patch(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + + Result Delete(const char *path); + Result Delete(const char *path, const std::string &body, + const char *content_type); + Result Delete(const char *path, const Headers &headers); + Result Delete(const char *path, const Headers &headers, + const std::string &body, const char *content_type); - Result Options(const char *path); - Result Options(const char *path, const Headers &headers); + Result Options(const char *path); + Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res); + bool send(const Request &req, Response &res); - size_t is_socket_open() const; + size_t is_socket_open() const; - void stop(); + void stop(); - void set_default_headers(Headers headers); + void set_default_headers(Headers headers); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_connection_timeout(time_t sec, time_t usec = 0); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); - void set_basic_auth(const char *username, const char *password); - void set_bearer_token_auth(const char *token); + void set_basic_auth(const char *username, const char *password); + void set_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const char *username, const char *password); + void set_digest_auth(const char *username, const char *password); #endif - void set_keep_alive(bool on); - void set_follow_location(bool on); + void set_keep_alive(bool on); + void set_follow_location(bool on); - void set_compress(bool on); + void set_compress(bool on); - void set_decompress(bool on); + void set_decompress(bool on); - void set_interface(const char *intf); + void set_interface(const char *intf); - void set_proxy(const char *host, int port); - void set_proxy_basic_auth(const char *username, const char *password); - void set_proxy_bearer_token_auth(const char *token); + void set_proxy(const char *host, int port); + void set_proxy_basic_auth(const char *username, const char *password); + void set_proxy_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const char *username, const char *password); + void set_proxy_digest_auth(const char *username, const char *password); #endif - void set_logger(Logger logger); + void set_logger(Logger logger); protected: - struct Socket { - socket_t sock = INVALID_SOCKET; + struct Socket { + socket_t sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; + SSL *ssl = nullptr; #endif - bool is_open() const { return sock != INVALID_SOCKET; } - }; + bool is_open() const { + return sock != INVALID_SOCKET; + } + }; - virtual bool create_and_connect_socket(Socket &socket); - virtual void close_socket(Socket &socket, bool process_socket_ret); + virtual bool create_and_connect_socket(Socket &socket); + virtual void close_socket(Socket &socket, bool process_socket_ret); - bool process_request(Stream &strm, const Request &req, Response &res, - bool close_connection); + bool process_request(Stream &strm, const Request &req, Response &res, + bool close_connection); - Error get_last_error() const; + Error get_last_error() const; - // Error state - mutable Error error_ = Error::Success; + // Error state + mutable Error error_ = Error::Success; - // Socket endoint information - const std::string host_; - const int port_; - const std::string host_and_port_; + // Socket endoint information + const std::string host_; + const int port_; + const std::string host_and_port_; - // Current open socket - Socket socket_; - mutable std::mutex socket_mutex_; - std::recursive_mutex request_mutex_; + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; - // Default headers - Headers default_headers_; + // Default headers + Headers default_headers_; - // Settings - std::string client_cert_path_; - std::string client_key_path_; + // Settings + std::string client_cert_path_; + std::string client_key_path_; - time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; - time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - std::string basic_auth_username_; - std::string basic_auth_password_; - std::string bearer_token_auth_token_; + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string digest_auth_username_; - std::string digest_auth_password_; + std::string digest_auth_username_; + std::string digest_auth_password_; #endif - bool keep_alive_ = false; - bool follow_location_ = false; + bool keep_alive_ = false; + bool follow_location_ = false; - bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; - SocketOptions socket_options_ = nullptr; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = nullptr; - bool compress_ = false; - bool decompress_ = true; + bool compress_ = false; + bool decompress_ = true; - std::string interface_; + std::string interface_; - std::string proxy_host_; - int proxy_port_ = -1; + std::string proxy_host_; + int proxy_port_ = -1; - std::string proxy_basic_auth_username_; - std::string proxy_basic_auth_password_; - std::string proxy_bearer_token_auth_token_; + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string proxy_digest_auth_username_; - std::string proxy_digest_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; #endif - Logger logger_; - - void copy_settings(const ClientImpl &rhs) { - client_cert_path_ = rhs.client_cert_path_; - client_key_path_ = rhs.client_key_path_; - connection_timeout_sec_ = rhs.connection_timeout_sec_; - read_timeout_sec_ = rhs.read_timeout_sec_; - read_timeout_usec_ = rhs.read_timeout_usec_; - write_timeout_sec_ = rhs.write_timeout_sec_; - write_timeout_usec_ = rhs.write_timeout_usec_; - basic_auth_username_ = rhs.basic_auth_username_; - basic_auth_password_ = rhs.basic_auth_password_; - bearer_token_auth_token_ = rhs.bearer_token_auth_token_; + Logger logger_; + + void copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - digest_auth_username_ = rhs.digest_auth_username_; - digest_auth_password_ = rhs.digest_auth_password_; + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; #endif - keep_alive_ = rhs.keep_alive_; - follow_location_ = rhs.follow_location_; - tcp_nodelay_ = rhs.tcp_nodelay_; - socket_options_ = rhs.socket_options_; - compress_ = rhs.compress_; - decompress_ = rhs.decompress_; - interface_ = rhs.interface_; - proxy_host_ = rhs.proxy_host_; - proxy_port_ = rhs.proxy_port_; - proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; - proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; - proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + tcp_nodelay_ = rhs.tcp_nodelay_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; - proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; #endif - logger_ = rhs.logger_; - } + logger_ = rhs.logger_; + } private: - socket_t create_client_socket() const; - bool read_response_line(Stream &strm, Response &res); - bool write_request(Stream &strm, const Request &req, bool close_connection); - bool redirect(const Request &req, Response &res); - bool handle_request(Stream &strm, const Request &req, Response &res, - bool close_connection); - void stop_core(); - std::shared_ptr send_with_content_provider( - const char *method, const char *path, const Headers &headers, - const std::string &body, size_t content_length, - ContentProvider content_provider, const char *content_type); - - virtual bool process_socket(Socket &socket, - std::function callback); - virtual bool is_ssl() const; + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool close_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool close_connection); + void stop_core(); + std::shared_ptr send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_socket(Socket &socket, + std::function callback); + virtual bool is_ssl() const; }; class Client { public: - // Universal interface - explicit Client(const char *scheme_host_port); - - explicit Client(const char *scheme_host_port, - const std::string &client_cert_path, - const std::string &client_key_path); - - // HTTP only interface - explicit Client(const std::string &host, int port); - - explicit Client(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); - - ~Client(); - - bool is_valid() const; - - Result Get(const char *path); - Result Get(const char *path, const Headers &headers); - Result Get(const char *path, Progress progress); - Result Get(const char *path, const Headers &headers, Progress progress); - Result Get(const char *path, ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); - Result Get(const char *path, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver, Progress progress); - - Result Head(const char *path); - Result Head(const char *path, const Headers &headers); - - Result Post(const char *path); - Result Post(const char *path, const std::string &body, - const char *content_type); - Result Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Post(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Params ¶ms); - Result Post(const char *path, const Headers &headers, const Params ¶ms); - Result Post(const char *path, const MultipartFormDataItems &items); - Result Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - Result Put(const char *path); - Result Put(const char *path, const std::string &body, - const char *content_type); - Result Put(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Put(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Params ¶ms); - Result Put(const char *path, const Headers &headers, const Params ¶ms); - Result Patch(const char *path, const std::string &body, + // Universal interface + explicit Client(const char *scheme_host_port); + + explicit Client(const char *scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + ~Client(); + + bool is_valid() const; + + Result Get(const char *path); + Result Get(const char *path, const Headers &headers); + Result Get(const char *path, Progress progress); + Result Get(const char *path, const Headers &headers, Progress progress); + Result Get(const char *path, ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const char *path, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Head(const char *path); + Result Head(const char *path, const Headers &headers); + + Result Post(const char *path); + Result Post(const char *path, const std::string &body, + const char *content_type); + Result Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Post(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Params ¶ms); + Result Post(const char *path, const Headers &headers, const Params ¶ms); + Result Post(const char *path, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const char *path); + Result Put(const char *path, const std::string &body, const char *content_type); - Result Patch(const char *path, const Headers &headers, - const std::string &body, const char *content_type); - Result Patch(const char *path, size_t content_length, + Result Put(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type); - Result Patch(const char *path, const Headers &headers, size_t content_length, + Result Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type); + Result Put(const char *path, const Params ¶ms); + Result Put(const char *path, const Headers &headers, const Params ¶ms); + Result Patch(const char *path, const std::string &body, + const char *content_type); + Result Patch(const char *path, const Headers &headers, + const std::string &body, const char *content_type); + Result Patch(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); - Result Delete(const char *path); - Result Delete(const char *path, const std::string &body, - const char *content_type); - Result Delete(const char *path, const Headers &headers); - Result Delete(const char *path, const Headers &headers, - const std::string &body, const char *content_type); + Result Delete(const char *path); + Result Delete(const char *path, const std::string &body, + const char *content_type); + Result Delete(const char *path, const Headers &headers); + Result Delete(const char *path, const Headers &headers, + const std::string &body, const char *content_type); - Result Options(const char *path); - Result Options(const char *path, const Headers &headers); + Result Options(const char *path); + Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res); + bool send(const Request &req, Response &res); - size_t is_socket_open() const; + size_t is_socket_open() const; - void stop(); + void stop(); - void set_default_headers(Headers headers); + void set_default_headers(Headers headers); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_connection_timeout(time_t sec, time_t usec = 0); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); - void set_basic_auth(const char *username, const char *password); - void set_bearer_token_auth(const char *token); + void set_basic_auth(const char *username, const char *password); + void set_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const char *username, const char *password); + void set_digest_auth(const char *username, const char *password); #endif - void set_keep_alive(bool on); - void set_follow_location(bool on); + void set_keep_alive(bool on); + void set_follow_location(bool on); - void set_compress(bool on); + void set_compress(bool on); - void set_decompress(bool on); + void set_decompress(bool on); - void set_interface(const char *intf); + void set_interface(const char *intf); - void set_proxy(const char *host, int port); - void set_proxy_basic_auth(const char *username, const char *password); - void set_proxy_bearer_token_auth(const char *token); + void set_proxy(const char *host, int port); + void set_proxy_basic_auth(const char *username, const char *password); + void set_proxy_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const char *username, const char *password); + void set_proxy_digest_auth(const char *username, const char *password); #endif - void set_logger(Logger logger); + void set_logger(Logger logger); - // SSL + // SSL #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - Client &set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path = nullptr); + Client &set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path = nullptr); - Client &set_ca_cert_store(X509_STORE *ca_cert_store); + Client &set_ca_cert_store(X509_STORE *ca_cert_store); - Client &enable_server_certificate_verification(bool enabled); + Client &enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX *ssl_context() const; + SSL_CTX *ssl_context() const; #endif private: - std::shared_ptr cli_; + std::shared_ptr cli_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool is_ssl_ = false; + bool is_ssl_ = false; #endif -}; // namespace httplib +}; // namespace httplib #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLServer : public Server { public: - SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path = nullptr, - const char *client_ca_cert_dir_path = nullptr); + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); - ~SSLServer() override; + ~SSLServer() override; - bool is_valid() const override; + bool is_valid() const override; private: - bool process_and_close_socket(socket_t sock) override; + bool process_and_close_socket(socket_t sock) override; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; }; class SSLClient : public ClientImpl { public: - explicit SSLClient(const std::string &host); + explicit SSLClient(const std::string &host); - explicit SSLClient(const std::string &host, int port); + explicit SSLClient(const std::string &host, int port); - explicit SSLClient(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key); + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key); - ~SSLClient() override; + ~SSLClient() override; - bool is_valid() const override; + bool is_valid() const override; - void set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path = nullptr); + void set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path = nullptr); - void set_ca_cert_store(X509_STORE *ca_cert_store); + void set_ca_cert_store(X509_STORE *ca_cert_store); - void enable_server_certificate_verification(bool enabled); + void enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX *ssl_context() const; + SSL_CTX *ssl_context() const; private: - bool create_and_connect_socket(Socket &socket) override; - void close_socket(Socket &socket, bool process_socket_ret) override; + bool create_and_connect_socket(Socket &socket) override; + void close_socket(Socket &socket, bool process_socket_ret) override; - bool process_socket(Socket &socket, - std::function callback) override; - bool is_ssl() const override; + bool process_socket(Socket &socket, + std::function callback) override; + bool is_ssl() const override; - bool connect_with_proxy(Socket &sock, Response &res, bool &success); - bool initialize_ssl(Socket &socket); + bool connect_with_proxy(Socket &sock, Response &res, bool &success); + bool initialize_ssl(Socket &socket); - bool load_certs(); + bool load_certs(); - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; - std::once_flag initialize_cert_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; - std::vector host_components_; + std::vector host_components_; - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; - X509_STORE *ca_cert_store_ = nullptr; - bool server_certificate_verification_ = true; - long verify_result_ = 0; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + X509_STORE *ca_cert_store_ = nullptr; + bool server_certificate_verification_ = true; + long verify_result_ = 0; - friend class ClientImpl; + friend class ClientImpl; }; #endif @@ -1212,724 +1239,779 @@ class SSLClient : public ClientImpl { namespace detail { inline bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } - return false; + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; } inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) { - if (i >= s.size()) { return false; } - - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { return false; } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { - return false; + if (i >= s.size()) { + return false; + } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { + return false; + } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } } - } - return true; + return true; } inline std::string from_i_to_hex(size_t n) { - const char *charset = "0123456789abcdef"; - std::string ret; - do { - ret = charset[n & 15] + ret; - n >>= 4; - } while (n > 0); - return ret; + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; } inline bool start_with(const std::string &a, const std::string &b) { - if (a.size() < b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { - if (std::tolower(a[i]) != std::tolower(b[i])) { return false; } - } - return true; + if (a.size() < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (std::tolower(a[i]) != std::tolower(b[i])) { + return false; + } + } + return true; } inline size_t to_utf8(int code, char *buff) { - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); - buff[1] = static_cast(0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); - buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[2] = static_cast(0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED return 0; - } else if (code < 0x10000) { - buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); - buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[2] = static_cast(0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); - buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); - buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[3] = static_cast(0x80 | (code & 0x3F)); - return 4; - } - - // NOTREACHED - return 0; } // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c inline std::string base64_encode(const std::string &in) { - static const auto lookup = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(in.size()); + std::string out; + out.reserve(in.size()); - int val = 0; - int valb = -6; + int val = 0; + int valb = -6; - for (auto c : in) { - val = (val << 8) + static_cast(c); - valb += 8; - while (valb >= 0) { - out.push_back(lookup[(val >> valb) & 0x3F]); - valb -= 6; + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } } - } - if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + if (valb > -6) { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } - while (out.size() % 4) { - out.push_back('='); - } + while (out.size() % 4) { + out.push_back('='); + } - return out; + return out; } inline bool is_file(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); } inline bool is_dir(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); } inline bool is_valid_path(const std::string &path) { - size_t level = 0; - size_t i = 0; - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } + size_t level = 0; + size_t i = 0; - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; } - auto len = i - beg; - assert(len > 0); + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { return false; } - level--; - } else { - level++; - } + auto len = i - beg; + assert(len > 0); - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { + return false; + } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } } - } - return true; + return true; } inline std::string encode_url(const std::string &s) { - std::string result; - - for (size_t i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "%20"; break; - case '+': result += "%2B"; break; - case '\r': result += "%0D"; break; - case '\n': result += "%0A"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - // case ':': result += "%3A"; break; // ok? probably... - case ';': result += "%3B"; break; - default: - auto c = static_cast(s[i]); - if (c >= 0x80) { - result += '%'; - char hex[4]; - auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); - assert(len == 2); - result.append(hex, static_cast(len)); - } else { - result += s[i]; - } - break; + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': + result += "%20"; + break; + case '+': + result += "%2B"; + break; + case '\r': + result += "%0D"; + break; + case '\n': + result += "%0A"; + break; + case '\'': + result += "%27"; + break; + case ',': + result += "%2C"; + break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': + result += "%3B"; + break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } } - } - return result; + return result; } inline std::string decode_url(const std::string &s, bool convert_plus_to_space) { - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { result.append(buff, len); } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += static_cast(val); - i += 2; // '00' + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { + result.append(buff, len); + } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; } else { - result += s[i]; + result += s[i]; } - } - } else if (convert_plus_to_space && s[i] == '+') { - result += ' '; - } else { - result += s[i]; } - } - return result; + return result; } inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); } inline std::string file_extension(const std::string &path) { - std::smatch m; - static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, re)) { return m[1].str(); } - return std::string(); + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { + return m[1].str(); + } + return std::string(); } -inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +inline bool is_space_or_tab(char c) { + return c == ' ' || c == '\t'; +} inline std::pair trim(const char *b, const char *e, size_t left, size_t right) { - while (b + left < e && is_space_or_tab(b[left])) { - left++; - } - while (right > 0 && is_space_or_tab(b[right - 1])) { - right--; - } - return std::make_pair(left, right); + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); } inline std::string trim_copy(const std::string &s) { - auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); - return s.substr(r.first, r.second - r.first); + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); } -template void split(const char *b, const char *e, char d, Fn fn) { - size_t i = 0; - size_t beg = 0; +template +void split(const char *b, const char *e, char d, Fn fn) { + size_t i = 0; + size_t beg = 0; - while (e ? (b + i < e) : (b[i] != '\0')) { - if (b[i] == d) { - auto r = trim(b, e, beg, i); - if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } - beg = i + 1; + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + beg = i + 1; + } + i++; } - i++; - } - if (i) { - auto r = trim(b, e, beg, i); - if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } - } + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + } } // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { public: - stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) - : strm_(strm), fixed_buffer_(fixed_buffer), - fixed_buffer_size_(fixed_buffer_size) {} + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) { + } - const char *ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } } - } - size_t size() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_used_size_; - } else { - return glowable_buffer_.size(); + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } } - } - bool end_with_crlf() const { - auto end = ptr() + size(); - return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; - } + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } - for (size_t i = 0;; i++) { - char byte; - auto n = strm_.read(&byte, 1); + append(byte); - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; + if (byte == '\n') { + break; + } } - } - append(byte); - - if (byte == '\n') { break; } + return true; } - return true; - } - private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); - } - glowable_buffer_ += c; - } - } - - Stream &strm_; - char *fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_ = 0; - std::string glowable_buffer_; + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; }; inline int close_socket(socket_t sock) { #ifdef _WIN32 - return closesocket(sock); + return closesocket(sock); #else - return close(sock); + return close(sock); #endif } -template inline ssize_t handle_EINTR(T fn) { - ssize_t res = false; - while (true) { - res = fn(); - if (res < 0 && errno == EINTR) { continue; } - break; - } - return res; +template +inline ssize_t handle_EINTR(T fn) { + ssize_t res = false; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + continue; + } + break; + } + return res; } inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); #else - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); - return handle_EINTR([&]() { - return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); - }); + return handle_EINTR([&]() { + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); + }); #endif } inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLOUT; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); #else - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); - return handle_EINTR([&]() { - return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); - }); + return handle_EINTR([&]() { + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); + }); #endif } inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN | POLLOUT; - - auto timeout = static_cast(sec * 1000 + usec / 1000); - - auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); - - if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { - int error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - return res >= 0 && !error; - } - return false; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + return res >= 0 && !error; + } + return false; #else - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); - }); - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - int error = 0; - socklen_t len = sizeof(error); - return getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len) >= 0 && - !error; - } - return false; + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); + }); + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; #endif } class SocketStream : public Stream { public: - SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec); - ~SocketStream() override; + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec); + ~SocketStream() override; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: - socket_t sock_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { public: - SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec); - ~SSLSocketStream() override; + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: - socket_t sock_; - SSL *ssl_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #endif class BufferStream : public Stream { public: - BufferStream() = default; - ~BufferStream() override = default; + BufferStream() = default; + ~BufferStream() override = default; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; - const std::string &get_buffer() const; + const std::string &get_buffer() const; private: - std::string buffer; - size_t position = 0; + std::string buffer; + size_t position = 0; }; inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { - using namespace std::chrono; - auto start = steady_clock::now(); - while (true) { - auto val = select_read(sock, 0, 10000); - if (val < 0) { - return false; - } else if (val == 0) { - auto current = steady_clock::now(); - auto duration = duration_cast(current - start); - auto timeout = keep_alive_timeout_sec * 1000; - if (duration.count() > timeout) { return false; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } else { - return true; + using namespace std::chrono; + auto start = steady_clock::now(); + while (true) { + auto val = select_read(sock, 0, 10000); + if (val < 0) { + return false; + } else if (val == 0) { + auto current = steady_clock::now(); + auto duration = duration_cast(current - start); + auto timeout = keep_alive_timeout_sec * 1000; + if (duration.count() > timeout) { + return false; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + return true; + } } - } } -template +template inline bool process_server_socket_core(socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, T callback) { - assert(keep_alive_max_count > 0); - auto ret = false; - auto count = keep_alive_max_count; - while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { - auto close_connection = count == 1; - auto connection_closed = false; - ret = callback(close_connection, connection_closed); - if (!ret || connection_closed) { break; } - count--; - } - return ret; -} - -template + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { + break; + } + count--; + } + return ret; +} + +template inline bool process_server_socket(socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); + return process_server_socket_core( + sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); } -template +template inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm); + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); } inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif } -template +template socket_t create_socket(const char *host, int port, int socket_flags, bool tcp_nodelay, SocketOptions socket_options, BindOrConnect bind_or_connect) { - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { + if (getaddrinfo(host, service.c_str(), &hints, &result)) { #ifdef __linux__ - res_init(); + res_init(); #endif - return INVALID_SOCKET; - } + return INVALID_SOCKET; + } - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket #ifdef _WIN32 - auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, - nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); - /** - * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 - * and above the socket creation fails on older Windows Systems. - * - * Let's try to create a socket the old way in this case. - * - * Reference: - * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa - * - * WSA_FLAG_NO_HANDLE_INHERIT: - * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with - * SP1, and later - * - */ - if (sock == INVALID_SOCKET) { - sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - } + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } #else - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); #endif - if (sock == INVALID_SOCKET) { continue; } + if (sock == INVALID_SOCKET) { + continue; + } #ifdef __linux__ - if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + continue; + } #endif - if (tcp_nodelay) { - int yes = 1; - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), - sizeof(yes)); - } + if (tcp_nodelay) { + int yes = 1; + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), + sizeof(yes)); + } - if (socket_options) { socket_options(sock); } + if (socket_options) { + socket_options(sock); + } - if (rp->ai_family == AF_INET6) { - int no = 0; - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), - sizeof(no)); - } + if (rp->ai_family == AF_INET6) { + int no = 0; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), + sizeof(no)); + } - // bind or connect - if (bind_or_connect(sock, *rp)) { - freeaddrinfo(result); - return sock; - } + // bind or connect + if (bind_or_connect(sock, *rp)) { + freeaddrinfo(result); + return sock; + } - close_socket(sock); - } + close_socket(sock); + } - freeaddrinfo(result); - return INVALID_SOCKET; + freeaddrinfo(result); + return INVALID_SOCKET; } inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; - ioctlsocket(sock, FIONBIO, &flags); + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, - nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif } inline bool is_connection_error() { #ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; + return WSAGetLastError() != WSAEWOULDBLOCK; #else - return errno != EINPROGRESS; + return errno != EINPROGRESS; #endif } inline bool bind_ip_address(socket_t sock, const char *host) { - struct addrinfo hints; - struct addrinfo *result; + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; - if (getaddrinfo(host, "0", &hints, &result)) { return false; } + if (getaddrinfo(host, "0", &hints, &result)) { + return false; + } - auto ret = false; - for (auto rp = result; rp; rp = rp->ai_next) { - const auto &ai = *rp; - if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { - ret = true; - break; + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } } - } - freeaddrinfo(result); - return ret; + freeaddrinfo(result); + return ret; } #if !defined _WIN32 && !defined ANDROID @@ -1938,22 +2020,22 @@ inline bool bind_ip_address(socket_t sock, const char *host) { #ifdef USE_IF2IP inline std::string if2ip(const std::string &ifn) { - struct ifaddrs *ifap; - getifaddrs(&ifap); - for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { - if (ifa->ifa_addr && ifn == ifa->ifa_name) { - if (ifa->ifa_addr->sa_family == AF_INET) { - auto sa = reinterpret_cast(ifa->ifa_addr); - char buf[INET_ADDRSTRLEN]; - if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { - freeifaddrs(ifap); - return std::string(buf, INET_ADDRSTRLEN); - } - } - } - } - freeifaddrs(ifap); - return std::string(); + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); + return std::string(); } #endif @@ -1962,1346 +2044,1544 @@ inline socket_t create_client_socket(const char *host, int port, SocketOptions socket_options, time_t timeout_sec, time_t timeout_usec, const std::string &intf, Error &error) { - auto sock = create_socket( - host, port, 0, tcp_nodelay, socket_options, - [&](socket_t sock, struct addrinfo &ai) -> bool { - if (!intf.empty()) { + auto sock = create_socket( + host, port, 0, tcp_nodelay, socket_options, + [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { #ifdef USE_IF2IP - auto ip = if2ip(intf); - if (ip.empty()) { ip = intf; } - if (!bind_ip_address(sock, ip.c_str())) { - error = Error::BindIPAddress; - return false; - } + auto ip = if2ip(intf); + if (ip.empty()) { + ip = intf; + } + if (!bind_ip_address(sock, ip.c_str())) { + error = Error::BindIPAddress; + return false; + } #endif - } + } - set_nonblocking(sock, true); + set_nonblocking(sock, true); - auto ret = - ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + auto ret = + ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); - if (ret < 0) { - if (is_connection_error() || - !wait_until_socket_is_ready(sock, timeout_sec, timeout_usec)) { - close_socket(sock); - error = Error::Connection; - return false; - } - } + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, timeout_usec)) { + close_socket(sock); + error = Error::Connection; + return false; + } + } - set_nonblocking(sock, false); - error = Error::Success; - return true; - }); + set_nonblocking(sock, false); + error = Error::Success; + return true; + }); - if (sock != INVALID_SOCKET) { - error = Error::Success; - } else { - if (error == Error::Success) { error = Error::Connection; } - } + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { + error = Error::Connection; + } + } - return sock; + return sock; } inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { - if (addr.ss_family == AF_INET) { - port = ntohs(reinterpret_cast(&addr)->sin_port); - } else if (addr.ss_family == AF_INET6) { - port = - ntohs(reinterpret_cast(&addr)->sin6_port); - } + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } - std::array ipstr{}; - if (!getnameinfo(reinterpret_cast(&addr), addr_len, - ipstr.data(), static_cast(ipstr.size()), nullptr, - 0, NI_NUMERICHOST)) { - ip = ipstr.data(); - } + std::array ipstr{}; + if (!getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + ip = ipstr.data(); + } } inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); - if (!getpeername(sock, reinterpret_cast(&addr), - &addr_len)) { - get_remote_ip_and_port(addr, addr_len, ip, port); - } + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { + get_remote_ip_and_port(addr, addr_len, ip, port); + } } inline const char * find_content_type(const std::string &path, const std::map &user_data) { - auto ext = file_extension(path); - - auto it = user_data.find(ext); - if (it != user_data.end()) { return it->second.c_str(); } - - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html" || ext == "htm") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "wasm") { - return "application/wasm"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { + return it->second.c_str(); + } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; } inline const char *status_message(int status) { - switch (status) { - case 100: return "Continue"; - case 101: return "Switching Protocol"; - case 102: return "Processing"; - case 103: return "Early Hints"; - case 200: return "OK"; - case 201: return "Created"; - case 202: return "Accepted"; - case 203: return "Non-Authoritative Information"; - case 204: return "No Content"; - case 205: return "Reset Content"; - case 206: return "Partial Content"; - case 207: return "Multi-Status"; - case 208: return "Already Reported"; - case 226: return "IM Used"; - case 300: return "Multiple Choice"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 305: return "Use Proxy"; - case 306: return "unused"; - case 307: return "Temporary Redirect"; - case 308: return "Permanent Redirect"; - case 400: return "Bad Request"; - case 401: return "Unauthorized"; - case 402: return "Payment Required"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 405: return "Method Not Allowed"; - case 406: return "Not Acceptable"; - case 407: return "Proxy Authentication Required"; - case 408: return "Request Timeout"; - case 409: return "Conflict"; - case 410: return "Gone"; - case 411: return "Length Required"; - case 412: return "Precondition Failed"; - case 413: return "Payload Too Large"; - case 414: return "URI Too Long"; - case 415: return "Unsupported Media Type"; - case 416: return "Range Not Satisfiable"; - case 417: return "Expectation Failed"; - case 418: return "I'm a teapot"; - case 421: return "Misdirected Request"; - case 422: return "Unprocessable Entity"; - case 423: return "Locked"; - case 424: return "Failed Dependency"; - case 425: return "Too Early"; - case 426: return "Upgrade Required"; - case 428: return "Precondition Required"; - case 429: return "Too Many Requests"; - case 431: return "Request Header Fields Too Large"; - case 451: return "Unavailable For Legal Reasons"; - case 501: return "Not Implemented"; - case 502: return "Bad Gateway"; - case 503: return "Service Unavailable"; - case 504: return "Gateway Timeout"; - case 505: return "HTTP Version Not Supported"; - case 506: return "Variant Also Negotiates"; - case 507: return "Insufficient Storage"; - case 508: return "Loop Detected"; - case 510: return "Not Extended"; - case 511: return "Network Authentication Required"; - - default: - case 500: return "Internal Server Error"; - } + switch (status) { + case 100: + return "Continue"; + case 101: + return "Switching Protocol"; + case 102: + return "Processing"; + case 103: + return "Early Hints"; + case 200: + return "OK"; + case 201: + return "Created"; + case 202: + return "Accepted"; + case 203: + return "Non-Authoritative Information"; + case 204: + return "No Content"; + case 205: + return "Reset Content"; + case 206: + return "Partial Content"; + case 207: + return "Multi-Status"; + case 208: + return "Already Reported"; + case 226: + return "IM Used"; + case 300: + return "Multiple Choice"; + case 301: + return "Moved Permanently"; + case 302: + return "Found"; + case 303: + return "See Other"; + case 304: + return "Not Modified"; + case 305: + return "Use Proxy"; + case 306: + return "unused"; + case 307: + return "Temporary Redirect"; + case 308: + return "Permanent Redirect"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 402: + return "Payment Required"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 406: + return "Not Acceptable"; + case 407: + return "Proxy Authentication Required"; + case 408: + return "Request Timeout"; + case 409: + return "Conflict"; + case 410: + return "Gone"; + case 411: + return "Length Required"; + case 412: + return "Precondition Failed"; + case 413: + return "Payload Too Large"; + case 414: + return "URI Too Long"; + case 415: + return "Unsupported Media Type"; + case 416: + return "Range Not Satisfiable"; + case 417: + return "Expectation Failed"; + case 418: + return "I'm a teapot"; + case 421: + return "Misdirected Request"; + case 422: + return "Unprocessable Entity"; + case 423: + return "Locked"; + case 424: + return "Failed Dependency"; + case 425: + return "Too Early"; + case 426: + return "Upgrade Required"; + case 428: + return "Precondition Required"; + case 429: + return "Too Many Requests"; + case 431: + return "Request Header Fields Too Large"; + case 451: + return "Unavailable For Legal Reasons"; + case 501: + return "Not Implemented"; + case 502: + return "Bad Gateway"; + case 503: + return "Service Unavailable"; + case 504: + return "Gateway Timeout"; + case 505: + return "HTTP Version Not Supported"; + case 506: + return "Variant Also Negotiates"; + case 507: + return "Insufficient Storage"; + case 508: + return "Loop Detected"; + case 510: + return "Not Extended"; + case 511: + return "Network Authentication Required"; + + default: + case 500: + return "Internal Server Error"; + } } inline bool can_compress_content_type(const std::string &content_type) { - return (!content_type.find("text/") && content_type != "text/event-stream") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; + return (!content_type.find("text/") && content_type != "text/event-stream") || + content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; } -enum class EncodingType { None = 0, Gzip, Brotli }; +enum class EncodingType { None = 0, + Gzip, + Brotli }; inline EncodingType encoding_type(const Request &req, const Response &res) { - auto ret = - detail::can_compress_content_type(res.get_header_value("Content-Type")); - if (!ret) { return EncodingType::None; } + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { + return EncodingType::None; + } - const auto &s = req.get_header_value("Accept-Encoding"); - (void)(s); + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); #ifdef CPPHTTPLIB_BROTLI_SUPPORT - // TODO: 'Accept-Encoding' has br, not br;q=0 - ret = s.find("br") != std::string::npos; - if (ret) { return EncodingType::Brotli; } + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { + return EncodingType::Brotli; + } #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 - ret = s.find("gzip") != std::string::npos; - if (ret) { return EncodingType::Gzip; } + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { + return EncodingType::Gzip; + } #endif - return EncodingType::None; + return EncodingType::None; } class compressor { public: - virtual ~compressor(){}; + virtual ~compressor(){}; - typedef std::function Callback; - virtual bool compress(const char *data, size_t data_length, bool last, - Callback callback) = 0; + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; }; class decompressor { public: - virtual ~decompressor() {} + virtual ~decompressor() { + } - virtual bool is_valid() const = 0; + virtual bool is_valid() const = 0; - typedef std::function Callback; - virtual bool decompress(const char *data, size_t data_length, - Callback callback) = 0; + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; }; class nocompressor : public compressor { public: - ~nocompressor(){}; + ~nocompressor(){}; - bool compress(const char *data, size_t data_length, bool /*last*/, - Callback callback) override { - if (!data_length) { return true; } - return callback(data, data_length); - } + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override { + if (!data_length) { + return true; + } + return callback(data, data_length); + } }; #ifdef CPPHTTPLIB_ZLIB_SUPPORT class gzip_compressor : public compressor { public: - gzip_compressor() { - std::memset(&strm_, 0, sizeof(strm_)); - strm_.zalloc = Z_NULL; - strm_.zfree = Z_NULL; - strm_.opaque = Z_NULL; - - is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, - Z_DEFAULT_STRATEGY) == Z_OK; - } + gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; + } - ~gzip_compressor() { deflateEnd(&strm_); } + ~gzip_compressor() { + deflateEnd(&strm_); + } - bool compress(const char *data, size_t data_length, bool last, - Callback callback) override { - assert(is_valid_); + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override { + assert(is_valid_); - auto flush = last ? Z_FINISH : Z_NO_FLUSH; + auto flush = last ? Z_FINISH : Z_NO_FLUSH; - strm_.avail_in = static_cast(data_length); - strm_.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - int ret = Z_OK; + int ret = Z_OK; - std::array buff{}; - do { - strm_.avail_out = buff.size(); - strm_.next_out = reinterpret_cast(buff.data()); + std::array buff{}; + do { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = deflate(&strm_, flush); - assert(ret != Z_STREAM_ERROR); + ret = deflate(&strm_, flush); + assert(ret != Z_STREAM_ERROR); - if (!callback(buff.data(), buff.size() - strm_.avail_out)) { - return false; - } - } while (strm_.avail_out == 0); + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); - assert((last && ret == Z_STREAM_END) || (!last && ret == Z_OK)); - assert(strm_.avail_in == 0); - return true; - } + assert((last && ret == Z_STREAM_END) || (!last && ret == Z_OK)); + assert(strm_.avail_in == 0); + return true; + } private: - bool is_valid_ = false; - z_stream strm_; + bool is_valid_ = false; + z_stream strm_; }; class gzip_decompressor : public decompressor { public: - gzip_decompressor() { - std::memset(&strm_, 0, sizeof(strm_)); - strm_.zalloc = Z_NULL; - strm_.zfree = Z_NULL; - strm_.opaque = Z_NULL; - - // 15 is the value of wbits, which should be at the maximum possible value - // to ensure that any gzip stream can be decoded. The offset of 32 specifies - // that the stream type should be automatically detected either gzip or - // deflate. - is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; - } + gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; + } - ~gzip_decompressor() { inflateEnd(&strm_); } + ~gzip_decompressor() { + inflateEnd(&strm_); + } - bool is_valid() const override { return is_valid_; } + bool is_valid() const override { + return is_valid_; + } - bool decompress(const char *data, size_t data_length, - Callback callback) override { - assert(is_valid_); + bool decompress(const char *data, size_t data_length, + Callback callback) override { + assert(is_valid_); - int ret = Z_OK; + int ret = Z_OK; - strm_.avail_in = static_cast(data_length); - strm_.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - std::array buff{}; - while (strm_.avail_in > 0) { - strm_.avail_out = buff.size(); - strm_.next_out = reinterpret_cast(buff.data()); + std::array buff{}; + while (strm_.avail_in > 0) { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = inflate(&strm_, Z_NO_FLUSH); - assert(ret != Z_STREAM_ERROR); - switch (ret) { - case Z_NEED_DICT: - case Z_DATA_ERROR: - case Z_MEM_ERROR: inflateEnd(&strm_); return false; - } + ret = inflate(&strm_, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm_); + return false; + } - if (!callback(buff.data(), buff.size() - strm_.avail_out)) { - return false; - } - } + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } - return ret == Z_OK || ret == Z_STREAM_END; - } + return ret == Z_OK || ret == Z_STREAM_END; + } private: - bool is_valid_ = false; - z_stream strm_; + bool is_valid_ = false; + z_stream strm_; }; #endif #ifdef CPPHTTPLIB_BROTLI_SUPPORT class brotli_compressor : public compressor { public: - brotli_compressor() { - state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); - } + brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + } - ~brotli_compressor() { BrotliEncoderDestroyInstance(state_); } + ~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); + } - bool compress(const char *data, size_t data_length, bool last, - Callback callback) override { - std::array buff{}; + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { + break; + } + } else { + if (!available_in) { + break; + } + } - auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; - auto available_in = data_length; - auto next_in = reinterpret_cast(data); + auto available_out = buff.size(); + auto next_out = buff.data(); - for (;;) { - if (last) { - if (BrotliEncoderIsFinished(state_)) { break; } - } else { - if (!available_in) { break; } - } - - auto available_out = buff.size(); - auto next_out = buff.data(); - - if (!BrotliEncoderCompressStream(state_, operation, &available_in, - &next_in, &available_out, &next_out, - nullptr)) { - return false; - } + if (!BrotliEncoderCompressStream(state_, operation, &available_in, + &next_in, &available_out, &next_out, + nullptr)) { + return false; + } - auto output_bytes = buff.size() - available_out; - if (output_bytes) { - callback(reinterpret_cast(buff.data()), output_bytes); - } - } + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } - return true; - } + return true; + } private: - BrotliEncoderState *state_ = nullptr; + BrotliEncoderState *state_ = nullptr; }; class brotli_decompressor : public decompressor { public: - brotli_decompressor() { - decoder_s = BrotliDecoderCreateInstance(0, 0, 0); - decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT - : BROTLI_DECODER_RESULT_ERROR; - } - - ~brotli_decompressor() { - if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } - } + brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT : BROTLI_DECODER_RESULT_ERROR; + } - bool is_valid() const override { return decoder_s; } + ~brotli_decompressor() { + if (decoder_s) { + BrotliDecoderDestroyInstance(decoder_s); + } + } - bool decompress(const char *data, size_t data_length, - Callback callback) override { - if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || - decoder_r == BROTLI_DECODER_RESULT_ERROR) { - return 0; + bool is_valid() const override { + return decoder_s; } - const uint8_t *next_in = (const uint8_t *)data; - size_t avail_in = data_length; - size_t total_out; + bool decompress(const char *data, size_t data_length, + Callback callback) override { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } - decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + const uint8_t *next_in = (const uint8_t *)data; + size_t avail_in = data_length; + size_t total_out; - std::array buff{}; - while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { - char *next_out = buff.data(); - size_t avail_out = buff.size(); + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; - decoder_r = BrotliDecoderDecompressStream( - decoder_s, &avail_in, &next_in, &avail_out, - reinterpret_cast(&next_out), &total_out); + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); - if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); - if (!callback(buff.data(), buff.size() - avail_out)) { return false; } - } + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return false; + } - return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || - decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; - } + if (!callback(buff.data(), buff.size() - avail_out)) { + return false; + } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; + } private: - BrotliDecoderResult decoder_r; - BrotliDecoderState *decoder_s = nullptr; + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; }; #endif inline bool has_header(const Headers &headers, const char *key) { - return headers.find(key) != headers.end(); + return headers.find(key) != headers.end(); } inline const char *get_header_value(const Headers &headers, const char *key, size_t id = 0, const char *def = nullptr) { - auto rng = headers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second.c_str(); } - return def; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second.c_str(); + } + return def; } -template +template inline T get_header_value(const Headers & /*headers*/, const char * /*key*/, - size_t /*id*/ = 0, uint64_t /*def*/ = 0) {} + size_t /*id*/ = 0, uint64_t /*def*/ = 0) { +} -template <> +template<> inline uint64_t get_header_value(const Headers &headers, const char *key, size_t id, uint64_t def) { - auto rng = headers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { - return std::strtoull(it->second.data(), nullptr, 10); - } - return def; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; } -template +template inline bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; - } + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } - auto p = beg; - while (p < end && *p != ':') { - p++; - } + auto p = beg; + while (p < end && *p != ':') { + p++; + } - if (p == end) { return false; } + if (p == end) { + return false; + } - auto key_end = p; + auto key_end = p; - if (*p++ != ':') { return false; } + if (*p++ != ':') { + return false; + } - while (p < end && is_space_or_tab(*p)) { - p++; - } + while (p < end && is_space_or_tab(*p)) { + p++; + } - if (p < end) { - fn(std::string(beg, key_end), decode_url(std::string(p, end), false)); - return true; - } + if (p < end) { + fn(std::string(beg, key_end), decode_url(std::string(p, end), false)); + return true; + } - return false; + return false; } inline bool read_headers(Stream &strm, Headers &headers) { - const auto bufsiz = 2048; - char buf[bufsiz]; - stream_line_reader line_reader(strm, buf, bufsiz); + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); - for (;;) { - if (!line_reader.getline()) { return false; } + for (;;) { + if (!line_reader.getline()) { + return false; + } - // Check if the line ends with CRLF. - if (line_reader.end_with_crlf()) { - // Blank line indicates end of headers. - if (line_reader.size() == 2) { break; } - } else { - continue; // Skip invalid line. - } + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { + break; + } + } else { + continue; // Skip invalid line. + } - // Exclude CRLF - auto end = line_reader.ptr() + line_reader.size() - 2; + // Exclude CRLF + auto end = line_reader.ptr() + line_reader.size() - 2; - parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - headers.emplace(std::move(key), std::move(val)); - }); - } + parse_header(line_reader.ptr(), end, + [&](std::string &&key, std::string &&val) { + headers.emplace(std::move(key), std::move(val)); + }); + } - return true; + return true; } inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiver out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return false; } + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return false; + } - if (!out(buf, static_cast(n))) { return false; } + if (!out(buf, static_cast(n))) { + return false; + } - r += static_cast(n); + r += static_cast(n); - if (progress) { - if (!progress(r, len)) { return false; } + if (progress) { + if (!progress(r, len)) { + return false; + } + } } - } - return true; + return true; } inline void skip_content_with_length(Stream &strm, uint64_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += static_cast(n); - } + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return; + } + r += static_cast(n); + } } inline bool read_content_without_length(Stream &strm, ContentReceiver out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - for (;;) { - auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n < 0) { - return false; - } else if (n == 0) { - return true; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, static_cast(n))) { + return false; + } } - if (!out(buf, static_cast(n))) { return false; } - } - return true; + return true; } inline bool read_content_chunked(Stream &strm, ContentReceiver out) { - const auto bufsiz = 16; - char buf[bufsiz]; + const auto bufsiz = 16; + char buf[bufsiz]; - stream_line_reader line_reader(strm, buf, bufsiz); + stream_line_reader line_reader(strm, buf, bufsiz); - if (!line_reader.getline()) { return false; } + if (!line_reader.getline()) { + return false; + } - unsigned long chunk_len; - while (true) { - char *end_ptr; + unsigned long chunk_len; + while (true) { + char *end_ptr; - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); - if (end_ptr == line_reader.ptr()) { return false; } - if (chunk_len == ULONG_MAX) { return false; } + if (end_ptr == line_reader.ptr()) { + return false; + } + if (chunk_len == ULONG_MAX) { + return false; + } - if (chunk_len == 0) { break; } + if (chunk_len == 0) { + break; + } - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { - return false; - } + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } - if (!line_reader.getline()) { return false; } + if (!line_reader.getline()) { + return false; + } - if (strcmp(line_reader.ptr(), "\r\n")) { break; } + if (strcmp(line_reader.ptr(), "\r\n")) { + break; + } - if (!line_reader.getline()) { return false; } - } + if (!line_reader.getline()) { + return false; + } + } - if (chunk_len == 0) { - // Reader terminator after chunks - if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) - return false; - } + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } - return true; + return true; } inline bool is_chunked_transfer_encoding(const Headers &headers) { - return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), - "chunked"); + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); } -template +template bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver, bool decompress, U callback) { - if (decompress) { - std::string encoding = x.get_header_value("Content-Encoding"); - std::shared_ptr decompressor; + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::shared_ptr decompressor; - if (encoding.find("gzip") != std::string::npos || - encoding.find("deflate") != std::string::npos) { + if (encoding.find("gzip") != std::string::npos || + encoding.find("deflate") != std::string::npos) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = std::make_shared(); + decompressor = std::make_shared(); #else - status = 415; - return false; + status = 415; + return false; #endif - } else if (encoding.find("br") != std::string::npos) { + } else if (encoding.find("br") != std::string::npos) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = std::make_shared(); + decompressor = std::make_shared(); #else - status = 415; - return false; + status = 415; + return false; #endif - } + } - if (decompressor) { - if (decompressor->is_valid()) { - ContentReceiver out = [&](const char *buf, size_t n) { - return decompressor->decompress( - buf, n, - [&](const char *buf, size_t n) { return receiver(buf, n); }); - }; - return callback(out); - } else { - status = 500; - return false; - } + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiver out = [&](const char *buf, size_t n) { + return decompressor->decompress( + buf, n, + [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + return callback(out); + } else { + status = 500; + return false; + } + } } - } - ContentReceiver out = [&](const char *buf, size_t n) { - return receiver(buf, n); - }; - return callback(out); + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + return callback(out); } -template +template bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, Progress progress, ContentReceiver receiver, bool decompress) { - return prepare_content_receiver( - x, status, receiver, decompress, [&](const ContentReceiver &out) { - auto ret = true; - auto exceed_payload_max_length = false; - - if (is_chunked_transfer_encoding(x.headers)) { - ret = read_content_chunked(strm, out); - } else if (!has_header(x.headers, "Content-Length")) { - ret = read_content_without_length(strm, out); - } else { - auto len = get_header_value(x.headers, "Content-Length"); - if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; - } else if (len > 0) { - ret = read_content_with_length(strm, len, progress, out); - } - } + return prepare_content_receiver( + x, status, receiver, decompress, [&](const ContentReceiver &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value(x.headers, "Content-Length"); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } - if (!ret) { status = exceed_payload_max_length ? 413 : 400; } - return ret; - }); + if (!ret) { + status = exceed_payload_max_length ? 413 : 400; + } + return ret; + }); } -template +template inline ssize_t write_headers(Stream &strm, const T &info, const Headers &headers) { - ssize_t write_len = 0; - for (const auto &x : info.headers) { - if (x.first == "EXCEPTION_WHAT") { continue; } - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } - write_len += len; - } - for (const auto &x : headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } + ssize_t write_len = 0; + for (const auto &x : info.headers) { + if (x.first == "EXCEPTION_WHAT") { + continue; + } + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } write_len += len; - } - auto len = strm.write("\r\n"); - if (len < 0) { return len; } - write_len += len; - return write_len; + return write_len; } inline bool write_data(Stream &strm, const char *d, size_t l) { - size_t offset = 0; - while (offset < l) { - auto length = strm.write(d + offset, l - offset); - if (length < 0) { return false; } - offset += static_cast(length); - } - return true; + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { + return false; + } + offset += static_cast(length); + } + return true; } -template +template inline ssize_t write_content(Stream &strm, ContentProvider content_provider, size_t offset, size_t length, T is_shutting_down) { - size_t begin_offset = offset; - size_t end_offset = offset + length; - auto ok = true; - DataSink data_sink; + size_t begin_offset = offset; + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - offset += l; - if (!write_data(strm, d, l)) { ok = false; } - } - }; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { + ok = false; + } + } + }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (offset < end_offset && !is_shutting_down()) { - if (!content_provider(offset, end_offset - offset, data_sink)) { - return -1; + while (offset < end_offset && !is_shutting_down()) { + if (!content_provider(offset, end_offset - offset, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } } - if (!ok) { return -1; } - } - return static_cast(offset - begin_offset); + return static_cast(offset - begin_offset); } -template +template inline ssize_t write_content_without_length(Stream &strm, ContentProvider content_provider, T is_shutting_down) { - size_t offset = 0; - auto data_available = true; - auto ok = true; - DataSink data_sink; + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - offset += l; - if (!write_data(strm, d, l)) { ok = false; } - } - }; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { + ok = false; + } + } + }; - data_sink.done = [&](void) { data_available = false; }; + data_sink.done = [&](void) { data_available = false; }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (data_available && !is_shutting_down()) { - if (!content_provider(offset, 0, data_sink)) { return -1; } - if (!ok) { return -1; } - } + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } + } - return static_cast(offset); + return static_cast(offset); } -template +template inline ssize_t write_content_chunked(Stream &strm, ContentProvider content_provider, T is_shutting_down, U &compressor) { - size_t offset = 0; - auto data_available = true; - ssize_t total_written_length = 0; - auto ok = true; - DataSink data_sink; - - data_sink.write = [&](const char *d, size_t l) { - if (!ok) { return; } - - data_available = l > 0; - offset += l; - - std::string payload; - if (!compressor.compress(d, l, false, - [&](const char *data, size_t data_len) { - payload.append(data, data_len); - return true; - })) { - ok = false; - return; - } - - if (!payload.empty()) { - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (write_data(strm, chunk.data(), chunk.size())) { - total_written_length += chunk.size(); - } else { - ok = false; - return; - } - } - }; - - data_sink.done = [&](void) { - if (!ok) { return; } - - data_available = false; - - std::string payload; - if (!compressor.compress(nullptr, 0, true, - [&](const char *data, size_t data_len) { - payload.append(data, data_len); - return true; - })) { - ok = false; - return; - } - - if (!payload.empty()) { - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (write_data(strm, chunk.data(), chunk.size())) { - total_written_length += chunk.size(); - } else { - ok = false; - return; - } - } - - static const std::string done_marker("0\r\n\r\n"); - if (write_data(strm, done_marker.data(), done_marker.size())) { - total_written_length += done_marker.size(); - } else { - ok = false; - } - }; + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + auto ok = true; + DataSink data_sink; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.write = [&](const char *d, size_t l) { + if (!ok) { + return; + } + + data_available = l > 0; + offset += l; + + std::string payload; + if (!compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (write_data(strm, chunk.data(), chunk.size())) { + total_written_length += chunk.size(); + } else { + ok = false; + return; + } + } + }; - while (data_available && !is_shutting_down()) { - if (!content_provider(offset, 0, data_sink)) { return -1; } - if (!ok) { return -1; } - } + data_sink.done = [&](void) { + if (!ok) { + return; + } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (write_data(strm, chunk.data(), chunk.size())) { + total_written_length += chunk.size(); + } else { + ok = false; + return; + } + } - return total_written_length; + static const std::string done_marker("0\r\n\r\n"); + if (write_data(strm, done_marker.data(), done_marker.size())) { + total_written_length += done_marker.size(); + } else { + ok = false; + } + }; + + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } + } + + return total_written_length; } -template +template inline bool redirect(T &cli, const Request &req, Response &res, const std::string &path) { - Request new_req = req; - new_req.path = path; - new_req.redirect_count -= 1; - - if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { - new_req.method = "GET"; - new_req.body.clear(); - new_req.headers.clear(); - } + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } - Response new_res; + Response new_res; - auto ret = cli.send(new_req, new_res); - if (ret) { res = new_res; } - return ret; + auto ret = cli.send(new_req, new_res); + if (ret) { + res = new_res; + } + return ret; } inline std::string params_to_query_str(const Params ¶ms) { - std::string query; + std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } - query += it->first; - query += "="; - query += encode_url(it->second); - } - return query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { + query += "&"; + } + query += it->first; + query += "="; + query += encode_url(it->second); + } + return query; } inline void parse_query_text(const std::string &s, Params ¶ms) { - split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char *b2, const char *e2) { - if (key.empty()) { - key.assign(b2, e2); - } else { - val.assign(b2, e2); - } - }); + split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); - if (!key.empty()) { - params.emplace(decode_url(key, true), decode_url(val, true)); - } - }); + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); } inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { return false; } - boundary = content_type.substr(pos + 9); - if (boundary.length() >= 2 && boundary.front() == '"' && - boundary.back() == '"') { - boundary = boundary.substr(1, boundary.size() - 2); - } - return !boundary.empty(); + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { + return false; + } + boundary = content_type.substr(pos + 9); + if (boundary.length() >= 2 && boundary.front() == '"' && + boundary.back() == '"') { + boundary = boundary.substr(1, boundary.size() - 2); + } + return !boundary.empty(); } inline bool parse_range_header(const std::string &s, Ranges &ranges) { - static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); - std::smatch m; - if (std::regex_match(s, m, re_first_range)) { - auto pos = static_cast(m.position(1)); - auto len = static_cast(m.length(1)); - bool all_valid_ranges = true; - split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { - if (!all_valid_ranges) return; - static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch cm; - if (std::regex_match(b, e, cm, re_another_range)) { - ssize_t first = -1; - if (!cm.str(1).empty()) { - first = static_cast(std::stoll(cm.str(1))); - } - - ssize_t last = -1; - if (!cm.str(2).empty()) { - last = static_cast(std::stoll(cm.str(2))); - } - - if (first != -1 && last != -1 && first > last) { - all_valid_ranges = false; - return; - } - ranges.emplace_back(std::make_pair(first, last)); - } - }); - return all_valid_ranges; - } - return false; + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; } class MultipartFormDataParser { public: - MultipartFormDataParser() = default; - - void set_boundary(std::string &&boundary) { boundary_ = boundary; } - - bool is_valid() const { return is_valid_; } - - template - bool parse(const char *buf, size_t n, T content_callback, U header_callback) { - - static const std::regex re_content_disposition( - "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" - "\"(.*?)\")?\\s*$", - std::regex_constants::icase); - static const std::string dash_ = "--"; - static const std::string crlf_ = "\r\n"; - - buf_.append(buf, n); // TODO: performance improvement - - while (!buf_.empty()) { - switch (state_) { - case 0: { // Initial boundary - auto pattern = dash_ + boundary_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - auto pos = buf_.find(pattern); - if (pos != 0) { return false; } - buf_.erase(0, pattern.size()); - off_ += pattern.size(); - state_ = 1; - break; - } - case 1: { // New entry - clear_file_info(); - state_ = 2; - break; - } - case 2: { // Headers - auto pos = buf_.find(crlf_); - while (pos != std::string::npos) { - // Empty line - if (pos == 0) { - if (!header_callback(file_)) { - is_valid_ = false; - return false; - } - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); - state_ = 3; - break; - } - - static const std::string header_name = "content-type:"; - const auto header = buf_.substr(0, pos); - if (start_with(header, header_name)) { - file_.content_type = trim_copy(header.substr(header_name.size())); - } else { - std::smatch m; - if (std::regex_match(header, m, re_content_disposition)) { - file_.name = m[1]; - file_.filename = m[2]; - } - } - - buf_.erase(0, pos + crlf_.size()); - off_ += pos + crlf_.size(); - pos = buf_.find(crlf_); - } - if (state_ != 3) { return true; } - break; - } - case 3: { // Body - { - auto pattern = crlf_ + dash_; - if (pattern.size() > buf_.size()) { return true; } - - auto pos = buf_.find(pattern); - if (pos == std::string::npos) { - pos = buf_.size(); - while (pos > 0) { - auto c = buf_[pos - 1]; - if (c != '\r' && c != '\n' && c != '-') { break; } - pos--; - } - } + MultipartFormDataParser() = default; - if (!content_callback(buf_.data(), pos)) { - is_valid_ = false; - return false; - } + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + } - off_ += pos; - buf_.erase(0, pos); - } + bool is_valid() const { + return is_valid_; + } - { - auto pattern = crlf_ + dash_ + boundary_; - if (pattern.size() > buf_.size()) { return true; } - - auto pos = buf_.find(pattern); - if (pos != std::string::npos) { - if (!content_callback(buf_.data(), pos)) { - is_valid_ = false; - return false; + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + static const std::string dash_ = "--"; + static const std::string crlf_ = "\r\n"; + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { + return true; + } + auto pos = buf_.find(pattern); + if (pos != 0) { + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + static const std::string header_name = "content-type:"; + const auto header = buf_.substr(0, pos); + if (start_with(header, header_name)) { + file_.content_type = trim_copy(header.substr(header_name.size())); + } else { + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + if (state_ != 3) { + return true; + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { + return true; + } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { + pos = buf_.size(); + while (pos > 0) { + auto c = buf_[pos - 1]; + if (c != '\r' && c != '\n' && c != '-') { + break; + } + pos--; + } + } + + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { + return true; + } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { + return true; + } + if (buf_.compare(0, crlf_.size(), crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { + return true; + } + if (buf_.compare(0, pattern.size(), pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; } - - off_ += pos + pattern.size(); - buf_.erase(0, pos + pattern.size()); - state_ = 4; - } else { - if (!content_callback(buf_.data(), pattern.size())) { - is_valid_ = false; - return false; } - - off_ += pattern.size(); - buf_.erase(0, pattern.size()); - } - } - break; - } - case 4: { // Boundary - if (crlf_.size() > buf_.size()) { return true; } - if (buf_.compare(0, crlf_.size(), crlf_) == 0) { - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); - state_ = 1; - } else { - auto pattern = dash_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - if (buf_.compare(0, pattern.size(), pattern) == 0) { - buf_.erase(0, pattern.size()); - off_ += pattern.size(); - is_valid_ = true; - state_ = 5; - } else { - return true; - } } - break; - } - case 5: { // Done - is_valid_ = false; - return false; - } - } - } - return true; - } + return true; + } private: - void clear_file_info() { - file_.name.clear(); - file_.filename.clear(); - file_.content_type.clear(); - } - - std::string boundary_; - - std::string buf_; - size_t state_ = 0; - bool is_valid_ = false; - size_t off_ = 0; - MultipartFormData file_; + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + bool is_valid_ = false; + size_t off_ = 0; + MultipartFormData file_; }; inline std::string to_lower(const char *beg, const char *end) { - std::string out; - auto it = beg; - while (it != end) { - out += static_cast(::tolower(*it)); - it++; - } - return out; + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; } inline std::string make_multipart_data_boundary() { - static const char data[] = - "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - std::random_device seed_gen; - std::mt19937 engine(seed_gen()); + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); - std::string result = "--cpp-httplib-multipart-data-"; + std::string result = "--cpp-httplib-multipart-data-"; - for (auto i = 0; i < 16; i++) { - result += data[engine() % (sizeof(data) - 1)]; - } + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } - return result; + return result; } inline std::pair get_range_offset_and_length(const Request &req, size_t content_length, size_t index) { - auto r = req.ranges[index]; + auto r = req.ranges[index]; - if (r.first == -1 && r.second == -1) { - return std::make_pair(0, content_length); - } + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } - auto slen = static_cast(content_length); + auto slen = static_cast(content_length); - if (r.first == -1) { - r.first = slen - r.second; - r.second = slen - 1; - } + if (r.first == -1) { + r.first = slen - r.second; + r.second = slen - 1; + } - if (r.second == -1) { r.second = slen - 1; } + if (r.second == -1) { + r.second = slen - 1; + } - return std::make_pair(r.first, r.second - r.first + 1); + return std::make_pair(r.first, r.second - r.first + 1); } inline std::string make_content_range_header_field(size_t offset, size_t length, size_t content_length) { - std::string field = "bytes "; - field += std::to_string(offset); - field += "-"; - field += std::to_string(offset + length - 1); - field += "/"; - field += std::to_string(content_length); - return field; + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; } -template +template bool process_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, const std::string &content_type, SToken stoken, CToken ctoken, Content content) { - for (size_t i = 0; i < req.ranges.size(); i++) { - ctoken("--"); - stoken(boundary); - ctoken("\r\n"); - if (!content_type.empty()) { - ctoken("Content-Type: "); - stoken(content_type); - ctoken("\r\n"); - } + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } - auto offsets = get_range_offset_and_length(req, res.body.size(), i); - auto offset = offsets.first; - auto length = offsets.second; + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; - ctoken("Content-Range: "); - stoken(make_content_range_header_field(offset, length, res.body.size())); - ctoken("\r\n"); - ctoken("\r\n"); - if (!content(offset, length)) { return false; } - ctoken("\r\n"); - } + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { + return false; + } + ctoken("\r\n"); + } - ctoken("--"); - stoken(boundary); - ctoken("--\r\n"); + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); - return true; + return true; } inline std::string make_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - std::string data; - - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data += token; }, - [&](const char *token) { data += token; }, - [&](size_t offset, size_t length) { - data += res.body.substr(offset, length); - return true; - }); + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); - return data; + return data; } inline size_t get_multipart_ranges_data_length(const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - size_t data_length = 0; - - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data_length += token.size(); }, - [&](const char *token) { data_length += strlen(token); }, - [&](size_t /*offset*/, size_t length) { - data_length += length; - return true; - }); + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); - return data_length; + return data_length; } -template +template inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type, T is_shutting_down) { - return process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { strm.write(token); }, - [&](const char *token) { strm.write(token); }, - [&](size_t offset, size_t length) { - return write_content(strm, res.content_provider_, offset, length, - is_shutting_down) >= 0; - }); + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down) >= 0; + }); } inline std::pair get_range_offset_and_length(const Request &req, const Response &res, size_t index) { - auto r = req.ranges[index]; + auto r = req.ranges[index]; - if (r.second == -1) { - r.second = static_cast(res.content_length_) - 1; - } + if (r.second == -1) { + r.second = static_cast(res.content_length_) - 1; + } - return std::make_pair(r.first, r.second - r.first + 1); + return std::make_pair(r.first, r.second - r.first + 1); } inline bool expect_content(const Request &req) { - if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || - req.method == "PRI" || req.method == "DELETE") { - return true; - } - // TODO: check if Content-Length is set - return false; + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; } inline bool has_crlf(const char *s) { - auto p = s; - while (*p) { - if (*p == '\r' || *p == '\n') { return true; } - p++; - } - return false; + auto p = s; + while (*p) { + if (*p == '\r' || *p == '\n') { + return true; + } + p++; + } + return false; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -template +template inline std::string message_digest(const std::string &s, Init init, Update update, Final final, size_t digest_length) { - using namespace std; + using namespace std; - std::vector md(digest_length, 0); - CTX ctx; - init(&ctx); - update(&ctx, s.data(), s.size()); - final(md.data(), &ctx); + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); - stringstream ss; - for (auto c : md) { - ss << setfill('0') << setw(2) << hex << (unsigned int)c; - } - return ss.str(); + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); } inline std::string MD5(const std::string &s) { - return message_digest(s, MD5_Init, MD5_Update, MD5_Final, - MD5_DIGEST_LENGTH); + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); } inline std::string SHA_256(const std::string &s) { - return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, - SHA256_DIGEST_LENGTH); + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); } inline std::string SHA_512(const std::string &s) { - return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, - SHA512_DIGEST_LENGTH); + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); } #endif @@ -3310,37 +3590,41 @@ inline std::string SHA_512(const std::string &s) { // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store inline bool load_system_certs_on_windows(X509_STORE *store) { - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - if (!hStore) { return false; } + if (!hStore) { + return false; + } - PCCERT_CONTEXT pContext = NULL; - while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); + PCCERT_CONTEXT pContext = NULL; + while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + } } - } - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); - return true; + return true; } #endif class WSInit { public: - WSInit() { - WSADATA wsaData; - WSAStartup(0x0002, &wsaData); - } + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } - ~WSInit() { WSACleanup(); } + ~WSInit() { + WSACleanup(); + } }; static WSInit wsinit_; @@ -3351,340 +3635,355 @@ inline std::pair make_digest_authentication_header( const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, const std::string &username, const std::string &password, bool is_proxy = false) { - using namespace std; + using namespace std; - string nc; - { - stringstream ss; - ss << setfill('0') << setw(8) << hex << cnonce_count; - nc = ss.str(); - } + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } - auto qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else { - qop = "auth"; - } + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { + algo = auth.at("algorithm"); + } - string response; - { - auto H = algo == "SHA-256" - ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : + detail::MD5; - auto A1 = username + ":" + auth.at("realm") + ":" + password; + auto A1 = username + ":" + auth.at("realm") + ":" + password; - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + - "\", response=\"" + response + "\""; + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + + "\", response=\"" + response + "\""; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } #endif inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { - auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; - if (res.has_header(auth_key)) { - static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); - auto s = res.get_header_value(auth_key); - auto pos = s.find(' '); - if (pos != std::string::npos) { - auto type = s.substr(0, pos); - if (type == "Basic") { - return false; - } else if (type == "Digest") { - s = s.substr(pos + 1); - auto beg = std::sregex_iterator(s.begin(), s.end(), re); - for (auto i = beg; i != std::sregex_iterator(); ++i) { - auto m = *i; - auto key = s.substr(static_cast(m.position(1)), - static_cast(m.length(1))); - auto val = m.length(2) > 0 - ? s.substr(static_cast(m.position(2)), - static_cast(m.length(2))) - : s.substr(static_cast(m.position(3)), - static_cast(m.length(3))); - auth[key] = val; + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) : + s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } } - return true; - } } - } - return false; + return false; } // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 inline std::string random_string(size_t length) { - auto randchar = []() -> char { - const char charset[] = "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = (sizeof(charset) - 1); - return charset[static_cast(rand()) % max_index]; - }; - std::string str(length, 0); - std::generate_n(str.begin(), length, randchar); - return str; + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; } class ContentProviderAdapter { public: - explicit ContentProviderAdapter( - ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) { + } - bool operator()(size_t offset, size_t, DataSink &sink) { - return content_provider_(offset, sink); - } + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } private: - ContentProviderWithoutLength content_provider_; + ContentProviderWithoutLength content_provider_; }; -} // namespace detail +} // namespace detail // Header utilities inline std::pair make_range_header(Ranges ranges) { - std::string field = "bytes="; - auto i = 0; - for (auto r : ranges) { - if (i != 0) { field += ", "; } - if (r.first != -1) { field += std::to_string(r.first); } - field += '-'; - if (r.second != -1) { field += std::to_string(r.second); } - i++; - } - return std::make_pair("Range", field); + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { + field += ", "; + } + if (r.first != -1) { + field += std::to_string(r.first); + } + field += '-'; + if (r.second != -1) { + field += std::to_string(r.second); + } + i++; + } + return std::make_pair("Range", field); } inline std::pair make_basic_authentication_header(const std::string &username, const std::string &password, bool is_proxy = false) { - auto field = "Basic " + detail::base64_encode(username + ":" + password); - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } inline std::pair make_bearer_token_authentication_header(const std::string &token, bool is_proxy = false) { - auto field = "Bearer " + token; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } // Request implementation inline bool Request::has_header(const char *key) const { - return detail::has_header(headers, key); + return detail::has_header(headers, key); } inline std::string Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, id, ""); } -template +template inline T Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, 0); + return detail::get_header_value(headers, key, id, 0); } inline size_t Request::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline void Request::set_header(const char *key, const char *val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } } inline void Request::set_header(const char *key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } } inline bool Request::has_param(const char *key) const { - return params.find(key) != params.end(); + return params.find(key) != params.end(); } inline std::string Request::get_param_value(const char *key, size_t id) const { - auto rng = params.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second; + } + return std::string(); } inline size_t Request::get_param_value_count(const char *key) const { - auto r = params.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline bool Request::is_multipart_form_data() const { - const auto &content_type = get_header_value("Content-Type"); - return !content_type.find("multipart/form-data"); + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); } inline bool Request::has_file(const char *key) const { - return files.find(key) != files.end(); + return files.find(key) != files.end(); } inline MultipartFormData Request::get_file_value(const char *key) const { - auto it = files.find(key); - if (it != files.end()) { return it->second; } - return MultipartFormData(); + auto it = files.find(key); + if (it != files.end()) { + return it->second; + } + return MultipartFormData(); } // Response implementation inline bool Response::has_header(const char *key) const { - return headers.find(key) != headers.end(); + return headers.find(key) != headers.end(); } inline std::string Response::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, id, ""); } -template +template inline T Response::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, 0); + return detail::get_header_value(headers, key, id, 0); } inline size_t Response::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline void Response::set_header(const char *key, const char *val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } } inline void Response::set_header(const char *key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } } inline void Response::set_redirect(const char *url, int stat) { - if (!detail::has_crlf(url)) { - set_header("Location", url); - if (300 <= stat && stat < 400) { - this->status = stat; - } else { - this->status = 302; + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = 302; + } } - } } inline void Response::set_redirect(const std::string &url, int stat) { - set_redirect(url.c_str(), stat); + set_redirect(url.c_str(), stat); } inline void Response::set_content(const char *s, size_t n, const char *content_type) { - body.assign(s, n); - set_header("Content-Type", content_type); + body.assign(s, n); + set_header("Content-Type", content_type); } inline void Response::set_content(std::string s, const char *content_type) { - body = std::move(s); - set_header("Content-Type", content_type); + body = std::move(s); + set_header("Content-Type", content_type); } inline void Response::set_content_provider(size_t in_length, const char *content_type, ContentProvider provider, const std::function &resource_releaser) { - assert(in_length > 0); - set_header("Content-Type", content_type); - content_length_ = in_length; - content_provider_ = std::move(provider); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = false; + assert(in_length > 0); + set_header("Content-Type", content_type); + content_length_ = in_length; + content_provider_ = std::move(provider); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } inline void Response::set_content_provider(const char *content_type, ContentProviderWithoutLength provider, const std::function &resource_releaser) { - set_header("Content-Type", content_type); - content_length_ = 0; - content_provider_ = detail::ContentProviderAdapter(std::move(provider)); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = false; + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } inline void Response::set_chunked_content_provider( const char *content_type, ContentProviderWithoutLength provider, const std::function &resource_releaser) { - set_header("Content-Type", content_type); - content_length_ = 0; - content_provider_ = detail::ContentProviderAdapter(std::move(provider)); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = true; + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = true; } // Rstream implementation inline ssize_t Stream::write(const char *ptr) { - return write(ptr, strlen(ptr)); + return write(ptr, strlen(ptr)); } inline ssize_t Stream::write(const std::string &s) { - return write(s.data(), s.size()); + return write(s.data(), s.size()); } -template -inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { - const auto bufsiz = 2048; - std::array buf; +template +inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { + const auto bufsiz = 2048; + std::array buf; #if defined(_MSC_VER) && _MSC_VER < 1900 - auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); + auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); #else - auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); #endif - if (sn <= 0) { return sn; } + if (sn <= 0) { + return sn; + } - auto n = static_cast(sn); + auto n = static_cast(sn); - if (n >= buf.size() - 1) { - std::vector glowable_buf(buf.size()); + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); - while (n >= glowable_buf.size() - 1) { - glowable_buf.resize(glowable_buf.size() * 2); + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); #if defined(_MSC_VER) && _MSC_VER < 1900 - n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), - glowable_buf.size() - 1, fmt, - args...)); + n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, + args...)); #else - n = static_cast( - snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); #endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); } - return write(&glowable_buf[0], n); - } else { - return write(buf.data(), n); - } } namespace detail { @@ -3697,75 +3996,88 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, : sock_(sock), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec) {} + write_timeout_usec_(write_timeout_usec) { +} -inline SocketStream::~SocketStream() {} +inline SocketStream::~SocketStream() { +} inline bool SocketStream::is_readable() const { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SocketStream::is_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; } inline ssize_t SocketStream::read(char *ptr, size_t size) { - if (!is_readable()) { return -1; } + if (!is_readable()) { + return -1; + } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return recv(sock_, ptr, static_cast(size), 0); + if (size > static_cast((std::numeric_limits::max)())) { + return -1; + } + return recv(sock_, ptr, static_cast(size), 0); #else - return handle_EINTR([&]() { return recv(sock_, ptr, size, 0); }); + return handle_EINTR([&]() { return recv(sock_, ptr, size, 0); }); #endif } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!is_writable()) { + return -1; + } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return send(sock_, ptr, static_cast(size), 0); + if (size > static_cast((std::numeric_limits::max)())) { + return -1; + } + return send(sock_, ptr, static_cast(size), 0); #else - return handle_EINTR([&]() { return send(sock_, ptr, size, 0); }); + return handle_EINTR([&]() { return send(sock_, ptr, size, 0); }); #endif } inline void SocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { - return detail::get_remote_ip_and_port(sock_, ip, port); + return detail::get_remote_ip_and_port(sock_, ip, port); } // Buffer stream implementation -inline bool BufferStream::is_readable() const { return true; } +inline bool BufferStream::is_readable() const { + return true; +} -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::is_writable() const { + return true; +} inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER <= 1900 - auto len_read = buffer._Copy_s(ptr, size, size, position); + auto len_read = buffer._Copy_s(ptr, size, size, position); #else - auto len_read = buffer.copy(ptr, size, position); + auto len_read = buffer.copy(ptr, size, position); #endif - position += static_cast(len_read); - return static_cast(len_read); + position += static_cast(len_read); + return static_cast(len_read); } inline ssize_t BufferStream::write(const char *ptr, size_t size) { - buffer.append(ptr, size); - return static_cast(size); + buffer.append(ptr, size); + return static_cast(size); } inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, - int & /*port*/) const {} + int & /*port*/) const { +} -inline const std::string &BufferStream::get_buffer() const { return buffer; } +inline const std::string &BufferStream::get_buffer() const { + return buffer; +} -} // namespace detail +} // namespace detail // HTTP server implementation inline Server::Server() @@ -3773,1182 +4085,1279 @@ inline Server::Server() [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }), svr_sock_(INVALID_SOCKET), is_running_(false) { #ifdef __linux__ - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif } -inline Server::~Server() {} +inline Server::~Server() { +} inline Server &Server::Get(const char *pattern, Handler handler) { - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Post(const char *pattern, Handler handler) { - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Post(const char *pattern, HandlerWithContentReader handler) { - post_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Put(const char *pattern, Handler handler) { - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Put(const char *pattern, HandlerWithContentReader handler) { - put_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Patch(const char *pattern, Handler handler) { - patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Patch(const char *pattern, HandlerWithContentReader handler) { - patch_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Delete(const char *pattern, Handler handler) { - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Delete(const char *pattern, HandlerWithContentReader handler) { - delete_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + delete_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Options(const char *pattern, Handler handler) { - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline bool Server::set_base_dir(const char *dir, const char *mount_point) { - return set_mount_point(mount_point, dir); + return set_mount_point(mount_point, dir); } inline bool Server::set_mount_point(const char *mount_point, const char *dir) { - if (detail::is_dir(dir)) { - std::string mnt = mount_point ? mount_point : "/"; - if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.emplace_back(mnt, dir); - return true; + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } } - } - return false; + return false; } inline bool Server::remove_mount_point(const char *mount_point) { - for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { - if (it->first == mount_point) { - base_dirs_.erase(it); - return true; + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } } - } - return false; + return false; } inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, const char *mime) { - file_extension_and_mimetype_map_[ext] = mime; + file_extension_and_mimetype_map_[ext] = mime; } inline void Server::set_file_request_handler(Handler handler) { - file_request_handler_ = std::move(handler); + file_request_handler_ = std::move(handler); } inline void Server::set_error_handler(Handler handler) { - error_handler_ = std::move(handler); + error_handler_ = std::move(handler); } -inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} inline void Server::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = socket_options; } -inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } +inline void Server::set_logger(Logger logger) { + logger_ = std::move(logger); +} inline void Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { - expect_100_continue_handler_ = std::move(handler); + expect_100_continue_handler_ = std::move(handler); } inline void Server::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; + keep_alive_max_count_ = count; } inline void Server::set_keep_alive_timeout(time_t sec) { - keep_alive_timeout_sec_ = sec; + keep_alive_timeout_sec_ = sec; } inline void Server::set_read_timeout(time_t sec, time_t usec) { - read_timeout_sec_ = sec; - read_timeout_usec_ = usec; + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } inline void Server::set_write_timeout(time_t sec, time_t usec) { - write_timeout_sec_ = sec; - write_timeout_usec_ = usec; + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; } inline void Server::set_idle_interval(time_t sec, time_t usec) { - idle_interval_sec_ = sec; - idle_interval_usec_ = usec; + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; } inline void Server::set_payload_max_length(size_t length) { - payload_max_length_ = length; + payload_max_length_ = length; } inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { - if (bind_internal(host, port, socket_flags) < 0) return false; - return true; + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; } inline int Server::bind_to_any_port(const char *host, int socket_flags) { - return bind_internal(host, 0, socket_flags); + return bind_internal(host, 0, socket_flags); } -inline bool Server::listen_after_bind() { return listen_internal(); } +inline bool Server::listen_after_bind() { + return listen_internal(); +} inline bool Server::listen(const char *host, int port, int socket_flags) { - return bind_to_port(host, port, socket_flags) && listen_internal(); + return bind_to_port(host, port, socket_flags) && listen_internal(); } -inline bool Server::is_running() const { return is_running_; } +inline bool Server::is_running() const { + return is_running_; +} inline void Server::stop() { - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); - detail::shutdown_socket(sock); - detail::close_socket(sock); - } + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } } inline bool Server::parse_request_line(const char *s, Request &req) { - const static std::regex re( - "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " - "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); - - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[5]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3], false); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { detail::parse_query_text(m[4], req.params); } + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3], false); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { + detail::parse_query_text(m[4], req.params); + } - return true; - } + return true; + } - return false; + return false; } inline bool Server::write_response(Stream &strm, bool close_connection, const Request &req, Response &res) { - assert(res.status != -1); + assert(res.status != -1); - if (400 <= res.status && error_handler_) { error_handler_(req, res); } - - detail::BufferStream bstrm; + if (400 <= res.status && error_handler_) { + error_handler_(req, res); + } - // Response line - if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, - detail::status_message(res.status))) { - return false; - } + detail::BufferStream bstrm; - // Headers - if (close_connection || req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); - } else { - std::stringstream ss; - ss << "timeout=" << keep_alive_timeout_sec_ - << ", max=" << keep_alive_max_count_; - res.set_header("Keep-Alive", ss.str()); - } + // Response line + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } - if (!res.has_header("Content-Type") && - (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { - res.set_header("Content-Type", "text/plain"); - } + // Headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::stringstream ss; + ss << "timeout=" << keep_alive_timeout_sec_ + << ", max=" << keep_alive_max_count_; + res.set_header("Keep-Alive", ss.str()); + } - if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { - res.set_header("Accept-Ranges", "bytes"); - } + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + res.set_header("Content-Type", "text/plain"); + } - std::string content_type; - std::string boundary; + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } - if (req.ranges.size() > 1) { - boundary = detail::make_multipart_data_boundary(); + std::string content_type; + std::string boundary; - auto it = res.headers.find("Content-Type"); - if (it != res.headers.end()) { - content_type = it->second; - res.headers.erase(it); - } + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); - res.headers.emplace("Content-Type", - "multipart/byteranges; boundary=" + boundary); - } + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } - auto type = detail::encoding_type(req, res); + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } - if (res.body.empty()) { - if (res.content_length_ > 0) { - size_t length = 0; - if (req.ranges.empty()) { - length = res.content_length_; - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length_, 0); - auto offset = offsets.first; - length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.content_length_); - res.set_header("Content-Range", content_range); - } else { - length = detail::get_multipart_ranges_data_length(req, res, boundary, - content_type); - } - res.set_header("Content-Length", std::to_string(length)); - } else { - if (res.content_provider_) { - if (res.is_chunked_content_provider) { - res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); - } - } - } else { - res.set_header("Content-Length", "0"); - } - } - } else { - if (req.ranges.empty()) { - ; - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.body.size(), 0); - auto offset = offsets.first; - auto length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.body.size()); - res.set_header("Content-Range", content_range); - res.body = res.body.substr(offset, length); + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } else { + res.set_header("Content-Length", "0"); + } + } } else { - res.body = - detail::make_multipart_ranges_data(req, res, boundary, content_type); - } + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); + } - if (type != detail::EncodingType::None) { - std::shared_ptr compressor; + if (type != detail::EncodingType::None) { + std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); - res.set_header("Content-Encoding", "gzip"); + compressor = std::make_shared(); + res.set_header("Content-Encoding", "gzip"); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); - res.set_header("Content-Encoding", "brotli"); + compressor = std::make_shared(); + res.set_header("Content-Encoding", "brotli"); #endif - } + } + + if (compressor) { + std::string compressed; - if (compressor) { - std::string compressed; + if (!compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + return false; + } - if (!compressor->compress(res.body.data(), res.body.size(), true, - [&](const char *data, size_t data_len) { - compressed.append(data, data_len); - return true; - })) { - return false; + res.body.swap(compressed); + } } - res.body.swap(compressed); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length); - } - - if (!detail::write_headers(bstrm, res, Headers())) { return false; } + if (!detail::write_headers(bstrm, res, Headers())) { + return false; + } - // Flush buffer - auto &data = bstrm.get_buffer(); - strm.write(data.data(), data.size()); + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); - // Body - auto ret = true; - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!strm.write(res.body)) { ret = false; } - } else if (res.content_provider_) { - if (!write_content_with_provider(strm, req, res, boundary, - content_type)) { - ret = false; - } + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { + ret = false; + } + } else if (res.content_provider_) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + ret = false; + } + } } - } - // Log - if (logger_) { logger_(req, res); } + // Log + if (logger_) { + logger_(req, res); + } - return ret; + return ret; } inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - auto is_shutting_down = [this]() { - return this->svr_sock_ == INVALID_SOCKET; - }; - - if (res.content_length_ > 0) { - if (req.ranges.empty()) { - if (detail::write_content(strm, res.content_provider_, 0, - res.content_length_, is_shutting_down) < 0) { - return false; - } - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length_, 0); - auto offset = offsets.first; - auto length = offsets.second; - if (detail::write_content(strm, res.content_provider_, offset, length, - is_shutting_down) < 0) { - return false; - } + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider_, offset, length, + is_shutting_down) < 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, is_shutting_down)) { + return false; + } + } } else { - if (!detail::write_multipart_ranges_data( - strm, req, res, boundary, content_type, is_shutting_down)) { - return false; - } - } - } else { - if (res.is_chunked_content_provider) { - auto type = detail::encoding_type(req, res); + if (res.is_chunked_content_provider) { + auto type = detail::encoding_type(req, res); - std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + std::shared_ptr compressor; + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else { - compressor = std::make_shared(); - } - assert(compressor != nullptr); + } else { + compressor = std::make_shared(); + } + assert(compressor != nullptr); - if (detail::write_content_chunked(strm, res.content_provider_, - is_shutting_down, *compressor) < 0) { - return false; - } - } else { - if (detail::write_content_without_length(strm, res.content_provider_, - is_shutting_down) < 0) { - return false; - } + if (detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor) < 0) { + return false; + } + } else { + if (detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down) < 0) { + return false; + } + } } - } - return true; + return true; } inline bool Server::read_content(Stream &strm, Request &req, Response &res) { - MultipartFormDataMap::iterator cur; - if (read_content_core( - strm, req, res, - // Regular - [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } - req.body.append(buf, n); - return true; - }, - // Multipart - [&](const MultipartFormData &file) { - cur = req.files.emplace(file.name, file); - return true; - }, - [&](const char *buf, size_t n) { - auto &content = cur->second.content; - if (content.size() + n > content.max_size()) { return false; } - content.append(buf, n); - return true; - })) { - const auto &content_type = req.get_header_value("Content-Type"); - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); + MultipartFormDataMap::iterator cur; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { + return false; + } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + return true; } - return true; - } - return false; + return false; } inline bool Server::read_content_with_content_receiver( Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader multipart_header, ContentReceiver multipart_receiver) { - return read_content_core(strm, req, res, receiver, multipart_header, - multipart_receiver); + return read_content_core(strm, req, res, receiver, multipart_header, + multipart_receiver); } inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader mulitpart_header, ContentReceiver multipart_receiver) { - detail::MultipartFormDataParser multipart_form_data_parser; - ContentReceiver out; + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return false; + } - if (req.is_multipart_form_data()) { - const auto &content_type = req.get_header_value("Content-Type"); - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary)) { - res.status = 400; - return false; - } - - multipart_form_data_parser.set_boundary(std::move(boundary)); - out = [&](const char *buf, size_t n) { - /* For debug - size_t pos = 0; - while (pos < n) { - auto read_size = std::min(1, n - pos); - auto ret = multipart_form_data_parser.parse( - buf + pos, read_size, multipart_receiver, mulitpart_header); - if (!ret) { return false; } - pos += read_size; - } - return true; - */ - return multipart_form_data_parser.parse(buf, n, multipart_receiver, - mulitpart_header); - }; - } else { - out = receiver; - } + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = std::min(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, mulitpart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { - return true; - } + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } - if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, - out, true)) { - return false; - } + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } - if (req.is_multipart_form_data()) { - if (!multipart_form_data_parser.is_valid()) { - res.status = 400; - return false; + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return false; + } } - } - return true; + return true; } inline bool Server::handle_file_request(Request &req, Response &res, bool head) { - for (const auto &kv : base_dirs_) { - const auto &mount_point = kv.first; - const auto &base_dir = kv.second; - - // Prefix match - if (!req.path.compare(0, mount_point.size(), mount_point)) { - std::string sub_path = "/" + req.path.substr(mount_point.size()); - if (detail::is_valid_path(sub_path)) { - auto path = base_dir + sub_path; - if (path.back() == '/') { path += "index.html"; } - - if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = - detail::find_content_type(path, file_extension_and_mimetype_map_); - if (type) { res.set_header("Content-Type", type); } - res.status = 200; - if (!head && file_request_handler_) { - file_request_handler_(req, res); - } - return true; - } - } - } - } - return false; + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.compare(0, mount_point.size(), mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { + path += "index.html"; + } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { + res.set_header("Content-Type", type); + } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; } inline socket_t Server::create_server_socket(const char *host, int port, int socket_flags, SocketOptions socket_options) const { - return detail::create_socket( - host, port, socket_flags, tcp_nodelay_, socket_options, - [](socket_t sock, struct addrinfo &ai) -> bool { - if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { - return false; - } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }); + return detail::create_socket( + host, port, socket_flags, tcp_nodelay_, socket_options, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }); } inline int Server::bind_internal(const char *host, int port, int socket_flags) { - if (!is_valid()) { return -1; } - - svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); - if (svr_sock_ == INVALID_SOCKET) { return -1; } + if (!is_valid()) { + return -1; + } - if (port == 0) { - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - if (getsockname(svr_sock_, reinterpret_cast(&addr), - &addr_len) == -1) { - return -1; + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { + return -1; } - if (addr.ss_family == AF_INET) { - return ntohs(reinterpret_cast(&addr)->sin_port); - } else if (addr.ss_family == AF_INET6) { - return ntohs(reinterpret_cast(&addr)->sin6_port); + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } } else { - return -1; + return port; } - } else { - return port; - } } inline bool Server::listen_internal() { - auto ret = true; - is_running_ = true; + auto ret = true; + is_running_ = true; - { - std::unique_ptr task_queue(new_task_queue()); + { + std::unique_ptr task_queue(new_task_queue()); - while (svr_sock_ != INVALID_SOCKET) { + while (svr_sock_ != INVALID_SOCKET) { #ifdef __linux__ - if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { #endif - auto val = detail::select_read(svr_sock_, idle_interval_sec_, - idle_interval_usec_); - if (val == 0) { // Timeout - task_queue->on_idle(); - continue; - } + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } #ifdef __linux__ - } + } #endif - socket_t sock = accept(svr_sock_, nullptr, nullptr); - - if (sock == INVALID_SOCKET) { - if (errno == EMFILE) { - // The per-process limit of open file descriptors has been reached. - // Try to accept new connections after a short sleep. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. - } - break; - } + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } #if __cplusplus > 201703L - task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); + task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); #else - task_queue->enqueue([=]() { process_and_close_socket(sock); }); + task_queue->enqueue([=]() { process_and_close_socket(sock); }); #endif - } + } - task_queue->shutdown(); - } + task_queue->shutdown(); + } - is_running_ = false; - return ret; + is_running_ = false; + return ret; } inline bool Server::routing(Request &req, Response &res, Stream &strm) { - // File handler - bool is_head_request = req.method == "HEAD"; - if ((req.method == "GET" || is_head_request) && - handle_file_request(req, res, is_head_request)) { - return true; - } + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } - if (detail::expect_content(req)) { - // Content reader handler - { - ContentReader reader( - [&](ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, receiver, - nullptr, nullptr); - }, - [&](MultipartContentHeader header, ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, nullptr, - header, receiver); - }); - - if (req.method == "POST") { - if (dispatch_request_for_content_reader( - req, res, reader, post_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "PUT") { - if (dispatch_request_for_content_reader( - req, res, reader, put_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "PATCH") { - if (dispatch_request_for_content_reader( - req, res, reader, patch_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "DELETE") { - if (dispatch_request_for_content_reader( - req, res, reader, delete_handlers_for_content_reader_)) { - return true; - } - } - } - - // Read content into `req.body` - if (!read_content(strm, req, res)) { return false; } - } - - // Regular handler - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } else if (req.method == "PATCH") { - return dispatch_request(req, res, patch_handlers_); - } - - res.status = 400; - return false; + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, receiver, + nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, reader, delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { + return false; + } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; } inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) { - try { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; + try { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; - } + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + } catch (const std::exception &ex) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", ex.what()); + } catch (...) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); } - } catch (const std::exception &ex) { - res.status = 500; - res.set_header("EXCEPTION_WHAT", ex.what()); - } catch (...) { - res.status = 500; - res.set_header("EXCEPTION_WHAT", "UNKNOWN"); - } - return false; + return false; } inline bool Server::dispatch_request_for_content_reader( Request &req, Response &res, ContentReader content_reader, const HandlersForContentReader &handlers) { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res, content_reader); - return true; + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } } - } - return false; + return false; } inline bool Server::process_request(Stream &strm, bool close_connection, bool &connection_closed, const std::function &setup_request) { - std::array buf{}; + std::array buf{}; - detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - // Connection has been closed on client - if (!line_reader.getline()) { return false; } + // Connection has been closed on client + if (!line_reader.getline()) { + return false; + } - Request req; - Response res; + Request req; + Response res; - res.version = "HTTP/1.1"; + res.version = "HTTP/1.1"; - // Check if the request URI doesn't exceed the limit - if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = 414; - return write_response(strm, close_connection, req, res); - } + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, close_connection, req, res); + } - // Request line and headers - if (!parse_request_line(line_reader.ptr(), req) || - !detail::read_headers(strm, req.headers)) { - res.status = 400; - return write_response(strm, close_connection, req, res); - } + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, close_connection, req, res); + } - if (req.get_header_value("Connection") == "close") { - connection_closed = true; - } + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } - if (req.version == "HTTP/1.0" && - req.get_header_value("Connection") != "Keep-Alive") { - connection_closed = true; - } + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } - strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); - req.set_header("REMOTE_ADDR", req.remote_addr); - req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); - if (req.has_header("Range")) { - const auto &range_header_value = req.get_header_value("Range"); - if (!detail::parse_range_header(range_header_value, req.ranges)) { - // TODO: error + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } } - } - - if (setup_request) { setup_request(req); } - if (req.get_header_value("Expect") == "100-continue") { - auto status = 100; - if (expect_100_continue_handler_) { - status = expect_100_continue_handler_(req, res); + if (setup_request) { + setup_request(req); } - switch (status) { - case 100: - case 417: - strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, - detail::status_message(status)); - break; - default: return write_response(strm, close_connection, req, res); + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + detail::status_message(status)); + break; + default: + return write_response(strm, close_connection, req, res); + } } - } - // Rounting - if (routing(req, res, strm)) { - if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } - } else { - if (res.status == -1) { res.status = 404; } - } + // Rounting + if (routing(req, res, strm)) { + if (res.status == -1) { + res.status = req.ranges.empty() ? 200 : 206; + } + } else { + if (res.status == -1) { + res.status = 404; + } + } - return write_response(strm, close_connection, req, res); + return write_response(strm, close_connection, req, res); } -inline bool Server::is_valid() const { return true; } +inline bool Server::is_valid() const { + return true; +} inline bool Server::process_and_close_socket(socket_t sock) { - auto ret = detail::process_server_socket( - sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, - nullptr); - }); + auto ret = detail::process_server_socket( + sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, + [this](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + nullptr); + }); - detail::shutdown_socket(sock); - detail::close_socket(sock); - return ret; + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; } // HTTP client implementation inline ClientImpl::ClientImpl(const std::string &host) - : ClientImpl(host, 80, std::string(), std::string()) {} + : ClientImpl(host, 80, std::string(), std::string()) { +} inline ClientImpl::ClientImpl(const std::string &host, int port) - : ClientImpl(host, port, std::string(), std::string()) {} + : ClientImpl(host, port, std::string(), std::string()) { +} inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(host), port_(port), host_and_port_(host_ + ":" + std::to_string(port_)), - client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + client_cert_path_(client_cert_path), client_key_path_(client_key_path) { +} -inline ClientImpl::~ClientImpl() { stop_core(); } +inline ClientImpl::~ClientImpl() { + stop_core(); +} -inline bool ClientImpl::is_valid() const { return true; } +inline bool ClientImpl::is_valid() const { + return true; +} -inline Error ClientImpl::get_last_error() const { return error_; } +inline Error ClientImpl::get_last_error() const { + return error_; +} inline socket_t ClientImpl::create_client_socket() const { - if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_.c_str(), proxy_port_, tcp_nodelay_, socket_options_, + connection_timeout_sec_, connection_timeout_usec_, interface_, error_); + } return detail::create_client_socket( - proxy_host_.c_str(), proxy_port_, tcp_nodelay_, socket_options_, + host_.c_str(), port_, tcp_nodelay_, socket_options_, connection_timeout_sec_, connection_timeout_usec_, interface_, error_); - } - return detail::create_client_socket( - host_.c_str(), port_, tcp_nodelay_, socket_options_, - connection_timeout_sec_, connection_timeout_usec_, interface_, error_); } inline bool ClientImpl::create_and_connect_socket(Socket &socket) { - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } - socket.sock = sock; - return true; + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { + return false; + } + socket.sock = sock; + return true; } inline void ClientImpl::close_socket(Socket &socket, bool /*process_socket_ret*/) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - socket_.ssl = nullptr; + socket_.ssl = nullptr; #endif } inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { - std::array buf; + std::array buf; - detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - if (!line_reader.getline()) { return false; } - - const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); - - std::cmatch m; - if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); - res.reason = std::string(m[3]); + if (!line_reader.getline()) { + return false; + } - // Ignore '100 Continue' - while (res.status == 100) { - if (!line_reader.getline()) { return false; } // CRLF - if (!line_reader.getline()) { return false; } // next response line + const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); - if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } res.version = std::string(m[1]); res.status = std::stoi(std::string(m[2])); res.reason = std::string(m[3]); - } - return true; + // Ignore '100 Continue' + while (res.status == 100) { + if (!line_reader.getline()) { + return false; + } // CRLF + if (!line_reader.getline()) { + return false; + } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; } inline bool ClientImpl::send(const Request &req, Response &res) { - std::lock_guard request_mutex_guard(request_mutex_); + std::lock_guard request_mutex_guard(request_mutex_); - { - std::lock_guard guard(socket_mutex_); + { + std::lock_guard guard(socket_mutex_); - auto is_alive = false; - if (socket_.is_open()) { - is_alive = detail::select_write(socket_.sock, 0, 0) > 0; - if (!is_alive) { close_socket(socket_, false); } - } + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::select_write(socket_.sock, 0, 0) > 0; + if (!is_alive) { + close_socket(socket_, false); + } + } - if (!is_alive) { - if (!create_and_connect_socket(socket_)) { return false; } + if (!is_alive) { + if (!create_and_connect_socket(socket_)) { + return false; + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // TODO: refactoring - if (is_ssl()) { - auto &scli = static_cast(*this); - if (!proxy_host_.empty() && proxy_port_ != -1) { - bool success = false; - if (!scli.connect_with_proxy(socket_, res, success)) { - return success; - } - } - - if (!scli.initialize_ssl(socket_)) { return false; } - } + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + bool success = false; + if (!scli.connect_with_proxy(socket_, res, success)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_)) { + return false; + } + } #endif + } } - } - auto close_connection = !keep_alive_; + auto close_connection = !keep_alive_; - auto ret = process_socket(socket_, [&](Stream &strm) { - return handle_request(strm, req, res, close_connection); - }); + auto ret = process_socket(socket_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection); + }); - if (close_connection || !ret) { stop_core(); } + if (close_connection || !ret) { + stop_core(); + } - if (!ret) { - if (error_ == Error::Success) { error_ = Error::Unknown; } - } + if (!ret) { + if (error_ == Error::Success) { + error_ = Error::Unknown; + } + } - return ret; + return ret; } inline bool ClientImpl::handle_request(Stream &strm, const Request &req, Response &res, bool close_connection) { - if (req.path.empty()) { - error_ = Error::Connection; - return false; - } + if (req.path.empty()) { + error_ = Error::Connection; + return false; + } - bool ret; + bool ret; - if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { - auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; - ret = process_request(strm, req2, res, close_connection); - } else { - ret = process_request(strm, req, res, close_connection); - } + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection); + } else { + ret = process_request(strm, req, res, close_connection); + } - if (!ret) { return false; } + if (!ret) { + return false; + } - if (300 < res.status && res.status < 400 && follow_location_) { - ret = redirect(req, res); - } + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if ((res.status == 401 || res.status == 407) && - req.authorization_count_ < 5) { - auto is_proxy = res.status == 407; - const auto &username = - is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; - const auto &password = - is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; - - if (!username.empty() && !password.empty()) { - std::map auth; - if (detail::parse_www_authenticate(res, auth, is_proxy)) { - Request new_req = req; - new_req.authorization_count_ += 1; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - new_req.headers.erase(key); - new_req.headers.insert(detail::make_digest_authentication_header( - req, auth, new_req.authorization_count_, detail::random_string(10), - username, password, is_proxy)); - - Response new_res; - - ret = send(new_req, new_res); - if (ret) { res = new_res; } - } - } - } + if ((res.status == 401 || res.status == 407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { + res = new_res; + } + } + } + } #endif - return ret; + return ret; } inline bool ClientImpl::redirect(const Request &req, Response &res) { - if (req.redirect_count == 0) { - error_ = Error::ExceedRedirectCount; - return false; - } + if (req.redirect_count == 0) { + error_ = Error::ExceedRedirectCount; + return false; + } - auto location = detail::decode_url(res.get_header_value("location"), true); - if (location.empty()) { return false; } + auto location = detail::decode_url(res.get_header_value("location"), true); + if (location.empty()) { + return false; + } - const static std::regex re( - R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + const static std::regex re( + R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); - std::smatch m; - if (!std::regex_match(location, m, re)) { return false; } + std::smatch m; + if (!std::regex_match(location, m, re)) { + return false; + } - auto scheme = is_ssl() ? "https" : "http"; + auto scheme = is_ssl() ? "https" : "http"; - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - auto port_str = m[3].str(); - auto next_path = m[4].str(); + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); - auto next_port = port_; - if (!port_str.empty()) { - next_port = std::stoi(port_str); - } else if (!next_scheme.empty()) { - next_port = next_scheme == "https" ? 443 : 80; - } + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } - if (next_scheme.empty()) { next_scheme = scheme; } - if (next_host.empty()) { next_host = host_; } - if (next_path.empty()) { next_path = "/"; } + if (next_scheme.empty()) { + next_scheme = scheme; + } + if (next_host.empty()) { + next_host = host_; + } + if (next_path.empty()) { + next_path = "/"; + } - if (next_scheme == scheme && next_host == host_ && next_port == port_) { - return detail::redirect(*this, req, res, next_path); - } else { - if (next_scheme == "https") { + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSLClient cli(next_host.c_str(), next_port); - cli.copy_settings(*this); - auto ret = detail::redirect(cli, req, res, next_path); - if (!ret) { error_ = cli.get_last_error(); } - return ret; + SSLClient cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + auto ret = detail::redirect(cli, req, res, next_path); + if (!ret) { + error_ = cli.get_last_error(); + } + return ret; #else - return false; + return false; #endif - } else { - ClientImpl cli(next_host.c_str(), next_port); - cli.copy_settings(*this); - auto ret = detail::redirect(cli, req, res, next_path); - if (!ret) { error_ = cli.get_last_error(); } - return ret; + } else { + ClientImpl cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + auto ret = detail::redirect(cli, req, res, next_path); + if (!ret) { + error_ = cli.get_last_error(); + } + return ret; + } } - } } inline bool ClientImpl::write_request(Stream &strm, const Request &req, bool close_connection) { - detail::BufferStream bstrm; - - // Request line - const auto &path = detail::encode_url(req.path); + detail::BufferStream bstrm; - bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + // Request line + const auto &path = detail::encode_url(req.path); - // Additonal headers - Headers headers; - if (close_connection) { headers.emplace("Connection", "close"); } + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - if (!req.has_header("Host")) { - if (is_ssl()) { - if (port_ == 443) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } - } else { - if (port_ == 80) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } + // Additonal headers + Headers headers; + if (close_connection) { + headers.emplace("Connection", "close"); } - } - - if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } - if (!req.has_header("User-Agent")) { - headers.emplace("User-Agent", "cpp-httplib/0.7"); - } - - if (req.body.empty()) { - if (req.content_provider) { - auto length = std::to_string(req.content_length); - headers.emplace("Content-Length", length); - } else { - headers.emplace("Content-Length", "0"); - } - } else { - if (!req.has_header("Content-Type")) { - headers.emplace("Content-Type", "text/plain"); + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } } - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - headers.emplace("Content-Length", length); + if (!req.has_header("Accept")) { + headers.emplace("Accept", "*/*"); } - } - if (!basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( - basic_auth_username_, basic_auth_password_, false)); - } + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.7"); + } - if (!proxy_basic_auth_username_.empty() && - !proxy_basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( - proxy_basic_auth_username_, proxy_basic_auth_password_, true)); - } + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } - if (!bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( - bearer_token_auth_token_, false)); - } + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } - if (!proxy_bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( - proxy_bearer_token_auth_token_, true)); - } + if (!basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } - detail::write_headers(bstrm, req, headers); + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } - // Flush buffer - auto &data = bstrm.get_buffer(); - if (!detail::write_data(strm, data.data(), data.size())) { - error_ = Error::Write; - return false; - } + if (!bearer_token_auth_token_.empty()) { + headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } - // Body - if (req.body.empty()) { - if (req.content_provider) { - size_t offset = 0; - size_t end_offset = req.content_length; + if (!proxy_bearer_token_auth_token_.empty()) { + headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } - bool ok = true; + detail::write_headers(bstrm, req, headers); - DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - if (detail::write_data(strm, d, l)) { - offset += l; - } else { - ok = false; - } - } - }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error_ = Error::Write; + return false; + } - while (offset < end_offset) { - if (!req.content_provider(offset, end_offset - offset, data_sink)) { - error_ = Error::Canceled; - return false; - } - if (!ok) { - error_ = Error::Write; - return false; + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + bool ok = true; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + if (detail::write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (offset < end_offset) { + if (!req.content_provider(offset, end_offset - offset, data_sink)) { + error_ = Error::Canceled; + return false; + } + if (!ok) { + error_ = Error::Write; + return false; + } + } } - } + } else { + return detail::write_data(strm, req.body.data(), req.body.size()); } - } else { - return detail::write_data(strm, req.body.data(), req.body.size()); - } - return true; + return true; } inline std::shared_ptr ClientImpl::send_with_content_provider( @@ -4956,541 +5365,572 @@ inline std::shared_ptr ClientImpl::send_with_content_provider( const std::string &body, size_t content_length, ContentProvider content_provider, const char *content_type) { - Request req; - req.method = method; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = method; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - if (content_type) { req.headers.emplace("Content-Type", content_type); } + if (content_type) { + req.headers.emplace("Content-Type", content_type); + } #ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (compress_) { - detail::gzip_compressor compressor; - - if (content_provider) { - auto ok = true; - size_t offset = 0; - - DataSink data_sink; - data_sink.write = [&](const char *data, size_t data_len) { - if (ok) { - auto last = offset + data_len == content_length; - - auto ret = compressor.compress( - data, data_len, last, [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - }); - - if (ret) { - offset += data_len; - } else { - ok = false; - } - } - }; - data_sink.is_writable = [&](void) { return ok && true; }; - - while (ok && offset < content_length) { - if (!content_provider(offset, content_length - offset, data_sink)) { - error_ = Error::Canceled; - return nullptr; + if (compress_) { + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && true; }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error_ = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body.data(), body.size(), true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + return nullptr; + } } - } - } else { - if (!compressor.compress(body.data(), body.size(), true, - [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - })) { - return nullptr; - } - } - req.headers.emplace("Content-Encoding", "gzip"); - } else + req.headers.emplace("Content-Encoding", "gzip"); + } else #endif - { - if (content_provider) { - req.content_length = content_length; - req.content_provider = content_provider; - } else { - req.body = body; + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } } - } - auto res = std::make_shared(); + auto res = std::make_shared(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? res : nullptr; } inline bool ClientImpl::process_request(Stream &strm, const Request &req, Response &res, bool close_connection) { - // Send request - if (!write_request(strm, req, close_connection)) { return false; } + // Send request + if (!write_request(strm, req, close_connection)) { + return false; + } - // Receive response and headers - if (!read_response_line(strm, res) || - !detail::read_headers(strm, res.headers)) { - error_ = Error::Read; - return false; - } + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + error_ = Error::Read; + return false; + } - if (req.response_handler) { - if (!req.response_handler(res)) { - error_ = Error::Canceled; - return false; + if (req.response_handler) { + if (!req.response_handler(res)) { + error_ = Error::Canceled; + return false; + } } - } - // Body - if (req.method != "HEAD" && req.method != "CONNECT") { - auto out = - req.content_receiver - ? static_cast([&](const char *buf, size_t n) { + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + auto out = + req.content_receiver ? static_cast([&](const char *buf, size_t n) { auto ret = req.content_receiver(buf, n); - if (!ret) { error_ = Error::Canceled; } + if (!ret) { + error_ = Error::Canceled; + } return ret; - }) - : static_cast([&](const char *buf, size_t n) { - if (res.body.size() + n > res.body.max_size()) { return false; } - res.body.append(buf, n); + }) : + static_cast([&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { + return false; + } + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress) { return true; - }); - - auto progress = [&](uint64_t current, uint64_t total) { - if (!req.progress) { return true; } - auto ret = req.progress(current, total); - if (!ret) { error_ = Error::Canceled; } - return ret; - }; + } + auto ret = req.progress(current, total); + if (!ret) { + error_ = Error::Canceled; + } + return ret; + }; - int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, progress, out, decompress_)) { - if (error_ != Error::Canceled) { error_ = Error::Read; } - return false; + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, progress, out, decompress_)) { + if (error_ != Error::Canceled) { + error_ = Error::Read; + } + return false; + } } - } - if (res.get_header_value("Connection") == "close" || - (res.version == "HTTP/1.0" && res.reason != "Connection established")) { - stop_core(); - } + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + stop_core(); + } - // Log - if (logger_) { logger_(req, res); } + // Log + if (logger_) { + logger_(req, res); + } - return true; + return true; } inline bool ClientImpl::process_socket(Socket &socket, std::function callback) { - return detail::process_client_socket(socket.sock, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, callback); + return detail::process_client_socket(socket.sock, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, callback); } -inline bool ClientImpl::is_ssl() const { return false; } +inline bool ClientImpl::is_ssl() const { + return false; +} inline Result ClientImpl::Get(const char *path) { - return Get(path, Headers(), Progress()); + return Get(path, Headers(), Progress()); } inline Result ClientImpl::Get(const char *path, Progress progress) { - return Get(path, Headers(), std::move(progress)); + return Get(path, Headers(), std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers) { - return Get(path, headers, Progress()); + return Get(path, headers, Progress()); } inline Result ClientImpl::Get(const char *path, const Headers &headers, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.progress = std::move(progress); + Request req; + req.method = "GET"; + req.path = path; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.progress = std::move(progress); - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Get(const char *path, ContentReceiver content_receiver) { - return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), nullptr, std::move(content_receiver), - std::move(progress)); + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) { - return Get(path, headers, nullptr, std::move(content_receiver), nullptr); + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, Progress progress) { - return Get(path, headers, nullptr, std::move(content_receiver), - std::move(progress)); + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - nullptr); + return Get(path, Headers(), std::move(response_handler), content_receiver, + nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, headers, std::move(response_handler), content_receiver, - nullptr); + return Get(path, headers, std::move(response_handler), content_receiver, + nullptr); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - progress); + return Get(path, Headers(), std::move(response_handler), content_receiver, + progress); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.response_handler = std::move(response_handler); - req.content_receiver = std::move(content_receiver); - req.progress = std::move(progress); + Request req; + req.method = "GET"; + req.path = path; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Head(const char *path) { - return Head(path, Headers()); + return Head(path, Headers()); } inline Result ClientImpl::Head(const char *path, const Headers &headers) { - Request req; - req.method = "HEAD"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = "HEAD"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Post(const char *path) { - return Post(path, std::string(), nullptr); + return Post(path, std::string(), nullptr); } inline Result ClientImpl::Post(const char *path, const std::string &body, const char *content_type) { - return Post(path, Headers(), body, content_type); + return Post(path, Headers(), body, content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Params ¶ms) { - return Post(path, Headers(), params); + return Post(path, Headers(), params); } inline Result ClientImpl::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Post(path, Headers(), content_length, content_provider, content_type); + return Post(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Headers &headers, const Params ¶ms) { - auto query = detail::params_to_query_str(params); - return Post(path, headers, query, "application/x-www-form-urlencoded"); + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); } inline Result ClientImpl::Post(const char *path, const MultipartFormDataItems &items) { - return Post(path, Headers(), items); + return Post(path, Headers(), items); } inline Result ClientImpl::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - auto boundary = detail::make_multipart_data_boundary(); + auto boundary = detail::make_multipart_data_boundary(); - std::string body; + std::string body; - for (const auto &item : items) { - body += "--" + boundary + "\r\n"; - body += "Content-Disposition: form-data; name=\"" + item.name + "\""; - if (!item.filename.empty()) { - body += "; filename=\"" + item.filename + "\""; - } - body += "\r\n"; - if (!item.content_type.empty()) { - body += "Content-Type: " + item.content_type + "\r\n"; + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; } - body += "\r\n"; - body += item.content + "\r\n"; - } - body += "--" + boundary + "--\r\n"; + body += "--" + boundary + "--\r\n"; - std::string content_type = "multipart/form-data; boundary=" + boundary; - return Post(path, headers, body, content_type.c_str()); + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); } inline Result ClientImpl::Put(const char *path) { - return Put(path, std::string(), nullptr); + return Put(path, std::string(), nullptr); } inline Result ClientImpl::Put(const char *path, const std::string &body, const char *content_type) { - return Put(path, Headers(), body, content_type); + return Put(path, Headers(), body, content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Put(path, Headers(), content_length, content_provider, content_type); + return Put(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Put(const char *path, const Params ¶ms) { - return Put(path, Headers(), params); + return Put(path, Headers(), params); } inline Result ClientImpl::Put(const char *path, const Headers &headers, const Params ¶ms) { - auto query = detail::params_to_query_str(params); - return Put(path, headers, query, "application/x-www-form-urlencoded"); + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); } inline Result ClientImpl::Patch(const char *path, const std::string &body, const char *content_type) { - return Patch(path, Headers(), body, content_type); + return Patch(path, Headers(), body, content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, body, 0, - nullptr, content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PATCH", path, headers, body, 0, + nullptr, content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Patch(path, Headers(), content_length, content_provider, content_type); + return Patch(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Delete(const char *path) { - return Delete(path, Headers(), std::string(), nullptr); + return Delete(path, Headers(), std::string(), nullptr); } inline Result ClientImpl::Delete(const char *path, const std::string &body, const char *content_type) { - return Delete(path, Headers(), body, content_type); + return Delete(path, Headers(), body, content_type); } inline Result ClientImpl::Delete(const char *path, const Headers &headers) { - return Delete(path, headers, std::string(), nullptr); + return Delete(path, headers, std::string(), nullptr); } inline Result ClientImpl::Delete(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - Request req; - req.method = "DELETE"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; - - if (content_type) { req.headers.emplace("Content-Type", content_type); } - req.body = body; + Request req; + req.method = "DELETE"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; + + if (content_type) { + req.headers.emplace("Content-Type", content_type); + } + req.body = body; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Options(const char *path) { - return Options(path, Headers()); + return Options(path, Headers()); } inline Result ClientImpl::Options(const char *path, const Headers &headers) { - Request req; - req.method = "OPTIONS"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = "OPTIONS"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline size_t ClientImpl::is_socket_open() const { - std::lock_guard guard(socket_mutex_); - return socket_.is_open(); + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); } inline void ClientImpl::stop() { - stop_core(); - error_ = Error::Canceled; + stop_core(); + error_ = Error::Canceled; } inline void ClientImpl::stop_core() { - std::lock_guard guard(socket_mutex_); - if (socket_.is_open()) { - detail::shutdown_socket(socket_.sock); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - close_socket(socket_, true); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } + std::lock_guard guard(socket_mutex_); + if (socket_.is_open()) { + detail::shutdown_socket(socket_.sock); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + close_socket(socket_, true); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } } inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { - connection_timeout_sec_ = sec; - connection_timeout_usec_ = usec; + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; } inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { - read_timeout_sec_ = sec; - read_timeout_usec_ = usec; + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { - write_timeout_sec_ = sec; - write_timeout_usec_ = usec; + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; } inline void ClientImpl::set_basic_auth(const char *username, const char *password) { - basic_auth_username_ = username; - basic_auth_password_ = password; + basic_auth_username_ = username; + basic_auth_password_ = password; } inline void ClientImpl::set_bearer_token_auth(const char *token) { - bearer_token_auth_token_ = token; + bearer_token_auth_token_ = token; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void ClientImpl::set_digest_auth(const char *username, const char *password) { - digest_auth_username_ = username; - digest_auth_password_ = password; + digest_auth_username_ = username; + digest_auth_password_ = password; } #endif -inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } +inline void ClientImpl::set_keep_alive(bool on) { + keep_alive_ = on; +} -inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } +inline void ClientImpl::set_follow_location(bool on) { + follow_location_ = on; +} inline void ClientImpl::set_default_headers(Headers headers) { - default_headers_ = std::move(headers); + default_headers_ = std::move(headers); } -inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void ClientImpl::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} inline void ClientImpl::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = socket_options; } -inline void ClientImpl::set_compress(bool on) { compress_ = on; } +inline void ClientImpl::set_compress(bool on) { + compress_ = on; +} -inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } +inline void ClientImpl::set_decompress(bool on) { + decompress_ = on; +} -inline void ClientImpl::set_interface(const char *intf) { interface_ = intf; } +inline void ClientImpl::set_interface(const char *intf) { + interface_ = intf; +} inline void ClientImpl::set_proxy(const char *host, int port) { - proxy_host_ = host; - proxy_port_ = port; + proxy_host_ = host; + proxy_port_ = port; } inline void ClientImpl::set_proxy_basic_auth(const char *username, const char *password) { - proxy_basic_auth_username_ = username; - proxy_basic_auth_password_ = password; + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; } inline void ClientImpl::set_proxy_bearer_token_auth(const char *token) { - proxy_bearer_token_auth_token_ = token; + proxy_bearer_token_auth_token_ = token; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void ClientImpl::set_proxy_digest_auth(const char *username, const char *password) { - proxy_digest_auth_username_ = username; - proxy_digest_auth_password_ = password; + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; } #endif inline void ClientImpl::set_logger(Logger logger) { - logger_ = std::move(logger); + logger_ = std::move(logger); } /* @@ -5499,66 +5939,66 @@ inline void ClientImpl::set_logger(Logger logger) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT namespace detail { -template +template inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup) { - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } - if (ssl) { - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); + if (ssl) { + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); - if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - return nullptr; + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + return nullptr; + } } - } - return ssl; + return ssl; } inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, bool process_socket_ret) { - if (process_socket_ret) { - SSL_shutdown(ssl); // shutdown only if not already closed by remote - } + if (process_socket_ret) { + SSL_shutdown(ssl); // shutdown only if not already closed by remote + } - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); } -template +template inline bool process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); + return process_server_socket_core( + sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); } -template +template inline bool process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm); + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); } #if OPENSSL_VERSION_NUMBER < 0x10100000L @@ -5566,49 +6006,51 @@ static std::shared_ptr> openSSL_locks_; class SSLThreadLocks { public: - SSLThreadLocks() { - openSSL_locks_ = - std::make_shared>(CRYPTO_num_locks()); - CRYPTO_set_locking_callback(locking_callback); - } + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } - ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + ~SSLThreadLocks() { + CRYPTO_set_locking_callback(nullptr); + } private: - static void locking_callback(int mode, int type, const char * /*file*/, - int /*line*/) { - auto &lk = (*openSSL_locks_)[static_cast(type)]; - if (mode & CRYPTO_LOCK) { - lk.lock(); - } else { - lk.unlock(); + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &lk = (*openSSL_locks_)[static_cast(type)]; + if (mode & CRYPTO_LOCK) { + lk.lock(); + } else { + lk.unlock(); + } } - } }; #endif class SSLInit { public: - SSLInit() { + SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - SSL_load_error_strings(); - SSL_library_init(); + SSL_load_error_strings(); + SSL_library_init(); #else - OPENSSL_init_ssl( - OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); #endif - } + } - ~SSLInit() { + ~SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - ERR_free_strings(); + ERR_free_strings(); #endif - } + } private: #if OPENSSL_VERSION_NUMBER < 0x10100000L - SSLThreadLocks thread_init_; + SSLThreadLocks thread_init_; #endif }; @@ -5622,839 +6064,904 @@ inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec) { - { - timeval tv; - tv.tv_sec = static_cast(read_timeout_sec); - tv.tv_usec = static_cast(read_timeout_usec); + { + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), - sizeof(tv)); - } - { - timeval tv; - tv.tv_sec = static_cast(write_timeout_sec); - tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + } + { + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); - setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), - sizeof(tv)); - } + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + } } -inline SSLSocketStream::~SSLSocketStream() {} +inline SSLSocketStream::~SSLSocketStream() { +} inline bool SSLSocketStream::is_readable() const { - return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SSLSocketStream::is_writable() const { - return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > - 0; + return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > + 0; } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0 || is_readable()) { - return SSL_read(ssl_, ptr, static_cast(size)); - } - return -1; + if (SSL_pending(ssl_) > 0 || is_readable()) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } - return -1; + if (is_writable()) { + return SSL_write(ssl_, ptr, static_cast(size)); + } + return -1; } inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { - detail::get_remote_ip_and_port(sock_, ip, port); + detail::get_remote_ip_and_port(sock_, ip, port); } static SSLInit sslinit_; -} // namespace detail +} // namespace detail // SSL HTTP server implementation inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, const char *client_ca_cert_dir_path) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); - - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - // if (client_ca_cert_file_path) { - // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); - // SSL_CTX_set_client_CA_list(ctx_, list); - // } - - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); - - SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } - } } inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - if (SSL_CTX_use_certificate(ctx_, cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_store) { - - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - - SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } - } } inline SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { + SSL_CTX_free(ctx_); + } } -inline bool SSLServer::is_valid() const { return ctx_; } +inline bool SSLServer::is_valid() const { + return ctx_; +} inline bool SSLServer::process_and_close_socket(socket_t sock) { - auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept, - [](SSL * /*ssl*/) { return true; }); - - if (ssl) { - auto ret = detail::process_server_socket_ssl( - ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, - [this, ssl](Stream &strm, bool close_connection, - bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, - [&](Request &req) { req.ssl = ssl; }); - }); - - detail::ssl_delete(ctx_mutex_, ssl, ret); - return ret; - } + auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept, + [](SSL * /*ssl*/) { return true; }); + + if (ssl) { + auto ret = detail::process_server_socket_ssl( + ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [this, ssl](Stream &strm, bool close_connection, + bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + detail::ssl_delete(ctx_mutex_, ssl, ret); + return ret; + } - detail::close_socket(sock); - return false; + detail::close_socket(sock); + return false; } // SSL HTTP client implementation inline SSLClient::SSLClient(const std::string &host) - : SSLClient(host, 443, std::string(), std::string()) {} + : SSLClient(host, 443, std::string(), std::string()) { +} inline SSLClient::SSLClient(const std::string &host, int port) - : SSLClient(host, port, std::string(), std::string()) {} + : SSLClient(host, port, std::string(), std::string()) { +} inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (!client_cert_path.empty() && !client_key_path.empty()) { - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } - } } inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key) : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (client_cert != nullptr && client_key != nullptr) { - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert != nullptr && client_key != nullptr) { + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } - } } inline SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { + SSL_CTX_free(ctx_); + } } -inline bool SSLClient::is_valid() const { return ctx_; } +inline bool SSLClient::is_valid() const { + return ctx_; +} inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } - if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } + if (ca_cert_file_path) { + ca_cert_file_path_ = ca_cert_file_path; + } + if (ca_cert_dir_path) { + ca_cert_dir_path_ = ca_cert_dir_path; + } } inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { ca_cert_store_ = ca_cert_store; } + if (ca_cert_store) { + ca_cert_store_ = ca_cert_store; + } } inline void SSLClient::enable_server_certificate_verification(bool enabled) { - server_certificate_verification_ = enabled; + server_certificate_verification_ = enabled; } inline long SSLClient::get_openssl_verify_result() const { - return verify_result_; + return verify_result_; } -inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } +inline SSL_CTX *SSLClient::ssl_context() const { + return ctx_; +} inline bool SSLClient::create_and_connect_socket(Socket &socket) { - return is_valid() && ClientImpl::create_and_connect_socket(socket); + return is_valid() && ClientImpl::create_and_connect_socket(socket); } inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success) { - success = true; - Response res2; - - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { - Request req2; - req2.method = "CONNECT"; - req2.path = host_and_port_; - return process_request(strm, req2, res2, false); - })) { - close_socket(socket, true); - success = false; - return false; - } - - if (res2.status == 407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (detail::parse_www_authenticate(res2, auth, true)) { - Response res3; - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { - Request req3; - req3.method = "CONNECT"; - req3.path = host_and_port_; - req3.headers.insert(detail::make_digest_authentication_header( - req3, auth, 1, detail::random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - return process_request(strm, req3, res3, false); - })) { - close_socket(socket, true); - success = false; - return false; - } - } - } else { - res = res2; - return false; + success = true; + Response res2; + + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false); + })) { + close_socket(socket, true); + success = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false); + })) { + close_socket(socket, true); + success = false; + return false; + } + } + } else { + res = res2; + return false; + } } - } - return true; + return true; } inline bool SSLClient::load_certs() { - bool ret = true; - - std::call_once(initialize_cert_, [&]() { - std::lock_guard guard(ctx_mutex_); - if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { - ret = false; - } - } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { - ret = false; - } - } else if (ca_cert_store_ != nullptr) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { - SSL_CTX_set_cert_store(ctx_, ca_cert_store_); - } - } else { + bool ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + ret = false; + } + } else if (ca_cert_store_ != nullptr) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { + SSL_CTX_set_cert_store(ctx_, ca_cert_store_); + } + } else { #ifdef _WIN32 - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); #else SSL_CTX_set_default_verify_paths(ctx_); #endif - } - }); + } + }); - return ret; + return ret; } inline bool SSLClient::initialize_ssl(Socket &socket) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl) { - if (server_certificate_verification_) { - if (!load_certs()) { - error_ = Error::SSLLoadingCerts; - return false; - } - SSL_set_verify(ssl, SSL_VERIFY_NONE, nullptr); - } + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (server_certificate_verification_) { + if (!load_certs()) { + error_ = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl, SSL_VERIFY_NONE, nullptr); + } - if (SSL_connect(ssl) != 1) { - error_ = Error::SSLConnection; - return false; - } + if (SSL_connect(ssl) != 1) { + error_ = Error::SSLConnection; + return false; + } - if (server_certificate_verification_) { - verify_result_ = SSL_get_verify_result(ssl); + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); - if (verify_result_ != X509_V_OK) { - error_ = Error::SSLServerVerification; - return false; - } + if (verify_result_ != X509_V_OK) { + error_ = Error::SSLServerVerification; + return false; + } - auto server_cert = SSL_get_peer_certificate(ssl); + auto server_cert = SSL_get_peer_certificate(ssl); - if (server_cert == nullptr) { - error_ = Error::SSLServerVerification; - return false; - } + if (server_cert == nullptr) { + error_ = Error::SSLServerVerification; + return false; + } - if (!verify_host(server_cert)) { - X509_free(server_cert); - error_ = Error::SSLServerVerification; - return false; - } - X509_free(server_cert); - } + if (!verify_host(server_cert)) { + X509_free(server_cert); + error_ = Error::SSLServerVerification; + return false; + } + X509_free(server_cert); + } - return true; - }, - [&](SSL *ssl) { - SSL_set_tlsext_host_name(ssl, host_.c_str()); - return true; - }); + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }); - if (ssl) { - socket.ssl = ssl; - return true; - } + if (ssl) { + socket.ssl = ssl; + return true; + } - close_socket(socket, false); - return false; + close_socket(socket, false); + return false; } inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); - socket_.ssl = nullptr; - } + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); + socket_.ssl = nullptr; + } } inline bool SSLClient::process_socket(Socket &socket, std::function callback) { - assert(socket.ssl); - return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, callback); + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, callback); } -inline bool SSLClient::is_ssl() const { return true; } +inline bool SSLClient::is_ssl() const { + return true; +} inline bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" + /* Quote from RFC2818 section 3.1 "Server Identity" - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); } inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; + auto ret = false; - auto type = GEN_DNS; + auto type = GEN_DNS; - struct in6_addr addr6; - struct in_addr addr; - size_t addr_len = 0; + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; #ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); - } + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } #endif - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_mached = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (val->type == type) { - auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); - auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); - - if (strlen(name) == name_len) { - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || - !memcmp(&addr, name, addr_len)) { - ip_mached = true; + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } } - break; - } } - } - } - if (dsn_matched || ip_mached) { ret = true; } - } + if (dsn_matched || ip_mached) { + ret = true; + } + } - GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); - return ret; + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + return ret; } inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); + const auto subject_name = X509_get_subject_name(server_cert); - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); - if (name_len != -1) { - return check_host_name(name, static_cast(name_len)); + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } } - } - return false; + return false; } inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const { - if (host_.size() == pattern_len && host_ == pattern) { return true; } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(std::string(b, e)); - }); + if (host_.size() == pattern_len && host_ == pattern) { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); - if (host_components_.size() != pattern_components.size()) { return false; } + if (host_components_.size() != pattern_components.size()) { + return false; + } - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (p != h && p != "*") { - auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && - !p.compare(0, p.size() - 1, h)); - if (!partial_match) { return false; } + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { + return false; + } + } + ++itr; } - ++itr; - } - return true; + return true; } #endif // Universal client implementation inline Client::Client(const char *scheme_host_port) - : Client(scheme_host_port, std::string(), std::string()) {} + : Client(scheme_host_port, std::string(), std::string()) { +} inline Client::Client(const char *scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { - const static std::regex re(R"(^(?:([a-z]+)://)?([^:/?#]+)(?::(\d+))?)"); + const static std::regex re(R"(^(?:([a-z]+)://)?([^:/?#]+)(?::(\d+))?)"); - std::cmatch m; - if (std::regex_match(scheme_host_port, m, re)) { - auto scheme = m[1].str(); + std::cmatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (!scheme.empty() && (scheme != "http" && scheme != "https")) { + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { #else - if (!scheme.empty() && scheme != "http") { + if (!scheme.empty() && scheme != "http") { #endif - return; - } + return; + } - auto is_ssl = scheme == "https"; + auto is_ssl = scheme == "https"; - auto host = m[2].str(); + auto host = m[2].str(); - auto port_str = m[3].str(); - auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + auto port_str = m[3].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); - if (is_ssl) { + if (is_ssl) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - cli_ = std::make_shared(host.c_str(), port, client_cert_path, - client_key_path); - is_ssl_ = is_ssl; + cli_ = std::make_shared(host.c_str(), port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; #endif + } else { + cli_ = std::make_shared(host.c_str(), port, client_cert_path, + client_key_path); + } } else { - cli_ = std::make_shared(host.c_str(), port, client_cert_path, - client_key_path); + cli_ = std::make_shared(scheme_host_port, 80, client_cert_path, + client_key_path); } - } else { - cli_ = std::make_shared(scheme_host_port, 80, client_cert_path, - client_key_path); - } } inline Client::Client(const std::string &host, int port) - : cli_(std::make_shared(host, port)) {} + : cli_(std::make_shared(host, port)) { +} inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : cli_(std::make_shared(host, port, client_cert_path, - client_key_path)) {} + client_key_path)) { +} -inline Client::~Client() {} +inline Client::~Client() { +} inline bool Client::is_valid() const { - return cli_ != nullptr && cli_->is_valid(); + return cli_ != nullptr && cli_->is_valid(); } -inline Result Client::Get(const char *path) { return cli_->Get(path); } +inline Result Client::Get(const char *path) { + return cli_->Get(path); +} inline Result Client::Get(const char *path, const Headers &headers) { - return cli_->Get(path, headers); + return cli_->Get(path, headers); } inline Result Client::Get(const char *path, Progress progress) { - return cli_->Get(path, progress); + return cli_->Get(path, progress); } inline Result Client::Get(const char *path, const Headers &headers, Progress progress) { - return cli_->Get(path, headers, progress); + return cli_->Get(path, headers, progress); } inline Result Client::Get(const char *path, ContentReceiver content_receiver) { - return cli_->Get(path, std::move(content_receiver)); + return cli_->Get(path, std::move(content_receiver)); } inline Result Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) { - return cli_->Get(path, headers, std::move(content_receiver)); + return cli_->Get(path, headers, std::move(content_receiver)); } inline Result Client::Get(const char *path, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, std::move(content_receiver), std::move(progress)); + return cli_->Get(path, std::move(content_receiver), std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, std::move(content_receiver), - std::move(progress)); + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); } inline Result Client::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return cli_->Get(path, std::move(response_handler), - std::move(content_receiver)); + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver)); } inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return cli_->Get(path, headers, std::move(response_handler), - std::move(content_receiver)); + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver)); } inline Result Client::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, std::move(response_handler), - std::move(content_receiver), std::move(progress)); + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, response_handler, content_receiver, progress); + return cli_->Get(path, headers, response_handler, content_receiver, progress); } -inline Result Client::Head(const char *path) { return cli_->Head(path); } +inline Result Client::Head(const char *path) { + return cli_->Head(path); +} inline Result Client::Head(const char *path, const Headers &headers) { - return cli_->Head(path, headers); + return cli_->Head(path, headers); } -inline Result Client::Post(const char *path) { return cli_->Post(path); } +inline Result Client::Post(const char *path) { + return cli_->Post(path); +} inline Result Client::Post(const char *path, const std::string &body, const char *content_type) { - return cli_->Post(path, body, content_type); + return cli_->Post(path, body, content_type); } inline Result Client::Post(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Post(path, headers, body, content_type); + return cli_->Post(path, headers, body, content_type); } inline Result Client::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, content_length, content_provider, content_type); + return cli_->Post(path, content_length, content_provider, content_type); } inline Result Client::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, headers, content_length, content_provider, - content_type); + return cli_->Post(path, headers, content_length, content_provider, + content_type); } inline Result Client::Post(const char *path, const Params ¶ms) { - return cli_->Post(path, params); + return cli_->Post(path, params); } inline Result Client::Post(const char *path, const Headers &headers, const Params ¶ms) { - return cli_->Post(path, headers, params); + return cli_->Post(path, headers, params); } inline Result Client::Post(const char *path, const MultipartFormDataItems &items) { - return cli_->Post(path, items); + return cli_->Post(path, items); } inline Result Client::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - return cli_->Post(path, headers, items); + return cli_->Post(path, headers, items); +} +inline Result Client::Put(const char *path) { + return cli_->Put(path); } -inline Result Client::Put(const char *path) { return cli_->Put(path); } inline Result Client::Put(const char *path, const std::string &body, const char *content_type) { - return cli_->Put(path, body, content_type); + return cli_->Put(path, body, content_type); } inline Result Client::Put(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Put(path, headers, body, content_type); + return cli_->Put(path, headers, body, content_type); } inline Result Client::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, content_length, content_provider, content_type); + return cli_->Put(path, content_length, content_provider, content_type); } inline Result Client::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, headers, content_length, content_provider, - content_type); + return cli_->Put(path, headers, content_length, content_provider, + content_type); } inline Result Client::Put(const char *path, const Params ¶ms) { - return cli_->Put(path, params); + return cli_->Put(path, params); } inline Result Client::Put(const char *path, const Headers &headers, const Params ¶ms) { - return cli_->Put(path, headers, params); + return cli_->Put(path, headers, params); } inline Result Client::Patch(const char *path, const std::string &body, const char *content_type) { - return cli_->Patch(path, body, content_type); + return cli_->Patch(path, body, content_type); } inline Result Client::Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Patch(path, headers, body, content_type); + return cli_->Patch(path, headers, body, content_type); } inline Result Client::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, content_length, content_provider, content_type); + return cli_->Patch(path, content_length, content_provider, content_type); } inline Result Client::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, headers, content_length, content_provider, - content_type); + return cli_->Patch(path, headers, content_length, content_provider, + content_type); +} +inline Result Client::Delete(const char *path) { + return cli_->Delete(path); } -inline Result Client::Delete(const char *path) { return cli_->Delete(path); } inline Result Client::Delete(const char *path, const std::string &body, const char *content_type) { - return cli_->Delete(path, body, content_type); + return cli_->Delete(path, body, content_type); } inline Result Client::Delete(const char *path, const Headers &headers) { - return cli_->Delete(path, headers); + return cli_->Delete(path, headers); } inline Result Client::Delete(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Delete(path, headers, body, content_type); + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Options(const char *path) { + return cli_->Options(path); } -inline Result Client::Options(const char *path) { return cli_->Options(path); } inline Result Client::Options(const char *path, const Headers &headers) { - return cli_->Options(path, headers); + return cli_->Options(path, headers); } inline bool Client::send(const Request &req, Response &res) { - return cli_->send(req, res); + return cli_->send(req, res); } -inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } +inline size_t Client::is_socket_open() const { + return cli_->is_socket_open(); +} -inline void Client::stop() { cli_->stop(); } +inline void Client::stop() { + cli_->stop(); +} inline void Client::set_default_headers(Headers headers) { - cli_->set_default_headers(std::move(headers)); + cli_->set_default_headers(std::move(headers)); } -inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } +inline void Client::set_tcp_nodelay(bool on) { + cli_->set_tcp_nodelay(on); +} inline void Client::set_socket_options(SocketOptions socket_options) { - cli_->set_socket_options(socket_options); + cli_->set_socket_options(socket_options); } inline void Client::set_connection_timeout(time_t sec, time_t usec) { - cli_->set_connection_timeout(sec, usec); + cli_->set_connection_timeout(sec, usec); } inline void Client::set_read_timeout(time_t sec, time_t usec) { - cli_->set_read_timeout(sec, usec); + cli_->set_read_timeout(sec, usec); } inline void Client::set_write_timeout(time_t sec, time_t usec) { - cli_->set_write_timeout(sec, usec); + cli_->set_write_timeout(sec, usec); } inline void Client::set_basic_auth(const char *username, const char *password) { - cli_->set_basic_auth(username, password); + cli_->set_basic_auth(username, password); } inline void Client::set_bearer_token_auth(const char *token) { - cli_->set_bearer_token_auth(token); + cli_->set_bearer_token_auth(token); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void Client::set_digest_auth(const char *username, const char *password) { - cli_->set_digest_auth(username, password); + cli_->set_digest_auth(username, password); } #endif -inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_keep_alive(bool on) { + cli_->set_keep_alive(on); +} inline void Client::set_follow_location(bool on) { - cli_->set_follow_location(on); + cli_->set_follow_location(on); } -inline void Client::set_compress(bool on) { cli_->set_compress(on); } +inline void Client::set_compress(bool on) { + cli_->set_compress(on); +} -inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } +inline void Client::set_decompress(bool on) { + cli_->set_decompress(on); +} inline void Client::set_interface(const char *intf) { - cli_->set_interface(intf); + cli_->set_interface(intf); } inline void Client::set_proxy(const char *host, int port) { - cli_->set_proxy(host, port); + cli_->set_proxy(host, port); } inline void Client::set_proxy_basic_auth(const char *username, const char *password) { - cli_->set_proxy_basic_auth(username, password); + cli_->set_proxy_basic_auth(username, password); } inline void Client::set_proxy_bearer_token_auth(const char *token) { - cli_->set_proxy_bearer_token_auth(token); + cli_->set_proxy_bearer_token_auth(token); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void Client::set_proxy_digest_auth(const char *username, const char *password) { - cli_->set_proxy_digest_auth(username, password); + cli_->set_proxy_digest_auth(username, password); } #endif -inline void Client::set_logger(Logger logger) { cli_->set_logger(logger); } +inline void Client::set_logger(Logger logger) { + cli_->set_logger(logger); +} #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline Client &Client::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - if (is_ssl_) { - static_cast(*cli_).set_ca_cert_path(ca_cert_file_path, - ca_cert_dir_path); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_path(ca_cert_file_path, + ca_cert_dir_path); + } + return *this; } inline Client &Client::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (is_ssl_) { - static_cast(*cli_).set_ca_cert_store(ca_cert_store); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } + return *this; } inline Client &Client::enable_server_certificate_verification(bool enabled) { - if (is_ssl_) { - static_cast(*cli_).enable_server_certificate_verification( - enabled); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).enable_server_certificate_verification( + enabled); + } + return *this; } inline long Client::get_openssl_verify_result() const { - if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); - } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } inline SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } - return nullptr; + if (is_ssl_) { + return static_cast(*cli_).ssl_context(); + } + return nullptr; } #endif // ---------------------------------------------------------------------------- -} // namespace httplib +} // namespace httplib -#endif // CPPHTTPLIB_HTTPLIB_H +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/src/bb/base/rt.h b/src/bb/base/rt.h index 8d593dd3..20b0450a 100644 --- a/src/bb/base/rt.h +++ b/src/bb/base/rt.h @@ -20,15 +20,15 @@ namespace base { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; -} // base -} // bb -} // ion +} // namespace base +} // namespace bb +} // namespace ion #define ION_REGISTER_EXTERN(NAME) static auto ion_register_extern_##NAME = ion::bb::base::RegisterExtern(#NAME, NAME); namespace ion { diff --git a/src/bb/bb.cc b/src/bb/bb.cc index fc853f08..22995a62 100644 --- a/src/bb/bb.cc +++ b/src/bb/bb.cc @@ -39,40 +39,40 @@ #include "llm/rt.h" #endif -extern "C" void register_externs(std::map& externs) { +extern "C" void register_externs(std::map &externs) { #if defined(ION_ENABLE_BB_BASE) - for (auto kv : ion::bb::base::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::base::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif #if defined(ION_ENABLE_BB_DNN) - for (auto kv : ion::bb::dnn::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::dnn::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif #if defined(ION_ENABLE_BB_IMAGE_IO) - for (auto kv : ion::bb::image_io::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::image_io::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif #if defined(ION_ENABLE_BB_IMAGE_PROCESSING) - for (auto kv : ion::bb::image_processing::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::image_processing::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif #if defined(ION_ENABLE_BB_OPENCV) - for (auto kv : ion::bb::opencv::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::opencv::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif #if defined(ION_ENABLE_BB_SGM) - for (auto kv : ion::bb::sgm::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::sgm::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif #if defined(ION_ENABLE_BB_LLM) - for (auto kv : ion::bb::llm::extern_functions) { - externs.insert({kv.first, Halide::JITExtern(kv.second)}); - } + for (auto kv : ion::bb::llm::extern_functions) { + externs.insert({kv.first, Halide::JITExtern(kv.second)}); + } #endif } diff --git a/src/bb/dnn/NvInferRuntime.h b/src/bb/dnn/NvInferRuntime.h index 192d0b79..8f2cd1b8 100644 --- a/src/bb/dnn/NvInferRuntime.h +++ b/src/bb/dnn/NvInferRuntime.h @@ -25,12 +25,11 @@ #include "NvInferRuntimeCommon.h" -namespace nvinfer1 -{ +namespace nvinfer1 { -class IExecutionContext; //!< Forward declaration of IExecutionContext for use by other interfaces. -class ICudaEngine; //!< Forward declaration of ICudaENgine for use by other interfaces. -class IPluginFactory; //!< Forward declaration of IPluginFactory for use by other interfaces. +class IExecutionContext; //!< Forward declaration of IExecutionContext for use by other interfaces. +class ICudaEngine; //!< Forward declaration of ICudaENgine for use by other interfaces. +class IPluginFactory; //!< Forward declaration of IPluginFactory for use by other interfaces. //! //! \enum EngineCapability @@ -46,19 +45,17 @@ class IPluginFactory; //!< Forward declaration of IPluginFactory for use by othe //! the resulting serialized engine can be executed using NvMediaDLA's runtime APIs. See sampleNvmedia for an //! example of integrating NvMediaDLA APIs with TensorRT APIs. //! -enum class EngineCapability : int32_t -{ - kDEFAULT = 0, //!< Full capability, TensorRT mode without any restrictions using TensorRT nvinfer1 APIs. - kSAFE_GPU = 1, //!< Safety restricted capability, TensorRT flow that can only run on GPU devices via TensorRT - //!< nvinfer1::safe APIs. - kSAFE_DLA = 2, //!< Safety restricted capability, TensorRT flow that can only run on DLA devices via - //!< NvMediaDLA APIs. +enum class EngineCapability : int32_t { + kDEFAULT = 0, //!< Full capability, TensorRT mode without any restrictions using TensorRT nvinfer1 APIs. + kSAFE_GPU = 1, //!< Safety restricted capability, TensorRT flow that can only run on GPU devices via TensorRT + //!< nvinfer1::safe APIs. + kSAFE_DLA = 2, //!< Safety restricted capability, TensorRT flow that can only run on DLA devices via + //!< NvMediaDLA APIs. }; //! Maximum number of elements in EngineCapability enum. \see EngineCapability -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 3; } @@ -74,12 +71,11 @@ constexpr inline int32_t EnumMax() //! The weights are held by reference until the engine has been built. Therefore the data referenced //! by \p values field should be preserved until the build is complete. //! -class Weights -{ +class Weights { public: - DataType type; //!< The type of the weights. - const void* values; //!< The weight values, in a contiguous array. - int64_t count; //!< The number of weights in the array. + DataType type; //!< The type of the weights. + const void *values; //!< The weight values, in a contiguous array. + int64_t count; //!< The number of weights in the array. }; //! @@ -92,15 +88,15 @@ class Weights //! //! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. //! -class IHostMemory -{ +class IHostMemory { public: - virtual void* data() const noexcept = 0; //!< A pointer to the raw data that is owned by the library. - virtual std::size_t size() const noexcept = 0; //!< The size in bytes of the data that was allocated. - virtual DataType type() const noexcept = 0; //!< The type of the memory that was allocated. - virtual void destroy() noexcept = 0; //!< Destroy the allocated memory. + virtual void *data() const noexcept = 0; //!< A pointer to the raw data that is owned by the library. + virtual std::size_t size() const noexcept = 0; //!< The size in bytes of the data that was allocated. + virtual DataType type() const noexcept = 0; //!< The type of the memory that was allocated. + virtual void destroy() noexcept = 0; //!< Destroy the allocated memory. protected: - virtual ~IHostMemory() {} + virtual ~IHostMemory() { + } }; //! \class IPlugin @@ -110,8 +106,7 @@ class IHostMemory //! Plugins are a mechanism for applications to implement custom layers. Each plugin is owned by the application, and its lifetime //! must span any use of it by TensorRT //! -class IPlugin -{ +class IPlugin { public: //! //! \brief Get the number of outputs from the layer. @@ -133,7 +128,7 @@ class IPlugin //! This function is called by the implementations of INetworkDefinition and IBuilder. In particular, it is called //! prior to any call to initialize(). //! - virtual Dims getOutputDimensions(int32_t index, const Dims* inputs, int32_t nbInputDims) TRTNOEXCEPT = 0; + virtual Dims getOutputDimensions(int32_t index, const Dims *inputs, int32_t nbInputDims) TRTNOEXCEPT = 0; //! //! \brief Configure the layer. @@ -153,8 +148,8 @@ class IPlugin //! //! This method is not called for PluginExt classes, configureWithFormat is called instead. //! - virtual void configure(const Dims* inputDims, int32_t nbInputs, const Dims* outputDims, int32_t nbOutputs, - int32_t maxBatchSize) TRTNOEXCEPT = 0; + virtual void configure(const Dims *inputDims, int32_t nbInputs, const Dims *outputDims, int32_t nbOutputs, + int32_t maxBatchSize) TRTNOEXCEPT = 0; //! //! \brief Initialize the layer for execution. This is called when the engine is created. @@ -190,8 +185,8 @@ class IPlugin //! //! \return 0 for success, else non-zero (which will cause engine termination). //! - virtual int32_t enqueue(int32_t batchSize, const void* const* inputs, void** outputs, void* workspace, - cudaStream_t stream) TRTNOEXCEPT = 0; + virtual int32_t enqueue(int32_t batchSize, const void *const *inputs, void **outputs, void *workspace, + cudaStream_t stream) TRTNOEXCEPT = 0; //! //! \brief Find the size of the serialization buffer required. @@ -207,9 +202,10 @@ class IPlugin //! //! \see getSerializationSize() //! - virtual void serialize(void* buffer) TRTNOEXCEPT = 0; + virtual void serialize(void *buffer) TRTNOEXCEPT = 0; - virtual ~IPlugin() {} + virtual ~IPlugin() { + } }; //! @@ -220,8 +216,7 @@ class IPlugin //! Plugins are a mechanism for applications to implement custom layers. Each plugin is owned by the application, and its lifetime //! must span any use of it by TensorRT. //! -class IPluginExt : public IPlugin -{ +class IPluginExt : public IPlugin { public: //! //! \brief Return the API version with which this plugin was built. @@ -229,8 +224,7 @@ class IPluginExt : public IPlugin //! Do not override this method as it is used by the TensorRT library to maintain backwards-compatibility with //! plugins. //! - virtual int32_t getTensorRTVersion() const TRTNOEXCEPT - { + virtual int32_t getTensorRTVersion() const TRTNOEXCEPT { return NV_TENSORRT_VERSION; } @@ -267,18 +261,18 @@ class IPluginExt : public IPlugin //! //! \warning DataType:kBOOL not supported. //! - virtual void configureWithFormat(const Dims* inputDims, int32_t nbInputs, const Dims* outputDims, int32_t nbOutputs, - DataType type, PluginFormat format, int32_t maxBatchSize) TRTNOEXCEPT = 0; + virtual void configureWithFormat(const Dims *inputDims, int32_t nbInputs, const Dims *outputDims, int32_t nbOutputs, + DataType type, PluginFormat format, int32_t maxBatchSize) TRTNOEXCEPT = 0; - virtual ~IPluginExt() {} + virtual ~IPluginExt() { + } protected: //! //! \brief Derived classes should not implement this. In a C++11 API it would be override final. //! - void configure(const Dims* /*inputDims*/, int32_t /*nbInputs*/, const Dims* /*outputDims*/, int32_t /*nbOutputs*/, - int32_t /*maxBatchSize*/) _TENSORRT_FINAL TRTNOEXCEPT - { + void configure(const Dims * /*inputDims*/, int32_t /*nbInputs*/, const Dims * /*outputDims*/, int32_t /*nbOutputs*/, + int32_t /*maxBatchSize*/) _TENSORRT_FINAL TRTNOEXCEPT { } }; @@ -292,23 +286,21 @@ class IPluginExt : public IPlugin //! //! \see IDimensionExpr, IExprBuilder //! -enum class DimensionOperation : int32_t -{ - kSUM = 0, //!< Sum of the two operands. - kPROD = 1, //!< Product of the two operands. - kMAX = 2, //!< Maximum of the two operands. - kMIN = 3, //!< Minimum of the two operands. - kSUB = 4, //!< Substract the second element from the first. - kEQUAL = 5, //!< 1 if operands are equal, 0 otherwise. - kLESS = 6, //!< 1 if first operand is less than second operand, 0 otherwise. - kFLOOR_DIV = 7, //!< Floor division of the first element by the second. - kCEIL_DIV = 8 //!< Division rounding up +enum class DimensionOperation : int32_t { + kSUM = 0, //!< Sum of the two operands. + kPROD = 1, //!< Product of the two operands. + kMAX = 2, //!< Maximum of the two operands. + kMIN = 3, //!< Minimum of the two operands. + kSUB = 4, //!< Substract the second element from the first. + kEQUAL = 5, //!< 1 if operands are equal, 0 otherwise. + kLESS = 6, //!< 1 if first operand is less than second operand, 0 otherwise. + kFLOOR_DIV = 7, //!< Floor division of the first element by the second. + kCEIL_DIV = 8 //!< Division rounding up }; //! Maximum number of elements in DimensionOperation enum. \see DimensionOperation -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 9; } @@ -322,8 +314,7 @@ constexpr inline int32_t EnumMax() //! //! \see DimensionOperation, IPluginV2DynamicExt::getOutputDimensions //! -class IDimensionExpr -{ +class IDimensionExpr { public: //! Return true if expression is a build-time constant. virtual bool isConstant() const = 0; @@ -333,7 +324,8 @@ class IDimensionExpr virtual int32_t getConstantValue() const = 0; protected: - virtual ~IDimensionExpr() {} + virtual ~IDimensionExpr() { + } }; //! @@ -353,18 +345,18 @@ class IDimensionExpr //! //! \see IDimensionExpr //! -class IExprBuilder -{ +class IExprBuilder { public: //! Return pointer to IDimensionExp for given value. - virtual const IDimensionExpr* constant(int32_t value) = 0; + virtual const IDimensionExpr *constant(int32_t value) = 0; //! Return pointer to IDimensionExp that represents the given operation applied to first and second. //! Returns nullptr if op is not a valid DimensionOperation. - virtual const IDimensionExpr* operation(DimensionOperation op, const IDimensionExpr& first, const IDimensionExpr& second) = 0; + virtual const IDimensionExpr *operation(DimensionOperation op, const IDimensionExpr &first, const IDimensionExpr &second) = 0; protected: - virtual ~IExprBuilder() {} + virtual ~IExprBuilder() { + } }; //! @@ -372,11 +364,10 @@ class IExprBuilder //! //! Analog of class Dims with expressions instead of constants for the dimensions. //! -class DimsExprs -{ +class DimsExprs { public: - int32_t nbDims; //!< The number of dimensions. - const IDimensionExpr* d[Dims::MAX_DIMS]; //!< The extent of each dimension. + int32_t nbDims; //!< The number of dimensions. + const IDimensionExpr *d[Dims::MAX_DIMS]; //!< The extent of each dimension. }; //! @@ -384,8 +375,7 @@ class DimsExprs //! //! Summarizes tensors that a plugin might see for an input or output. //! -struct DynamicPluginTensorDesc -{ +struct DynamicPluginTensorDesc { //! Information required to interpret a pointer to tensor data, except that desc.dims has -1 in place of any runtime dimension. PluginTensorDesc desc; @@ -414,10 +404,9 @@ struct DynamicPluginTensorDesc //! and the returned type is canonicalized to DataType::kFLOAT if it is DataType::kHALF or DataType:kINT8. //! Details about the floating-point precision are elicited later by method supportsFormatCombination. //! -class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext -{ +class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext { public: - IPluginV2DynamicExt* clone() const _TENSORRT_OVERRIDE TRTNOEXCEPT = 0; + IPluginV2DynamicExt *clone() const _TENSORRT_OVERRIDE TRTNOEXCEPT = 0; //! //! \brief Get expressions for computing dimensions of an output tensor from dimensions of the input tensors. @@ -444,8 +433,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! return output; //! virtual DimsExprs getOutputDimensions( - int32_t outputIndex, const DimsExprs* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) - = 0; + int32_t outputIndex, const DimsExprs *inputs, int32_t nbInputs, IExprBuilder &exprBuilder) = 0; //! //! Limit on number of format combinations accepted. @@ -485,7 +473,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! Warning: TensorRT will stop asking for formats once it finds kFORMAT_COMBINATION_LIMIT on combinations. //! virtual bool supportsFormatCombination( - int32_t pos, const PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) TRTNOEXCEPT = 0; + int32_t pos, const PluginTensorDesc *inOut, int32_t nbInputs, int32_t nbOutputs) TRTNOEXCEPT = 0; //! //! \brief Configure the layer. @@ -500,8 +488,8 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \param out The output tensors attributes that are used for configuration. //! \param nbOutputs Number of output tensors. //! - virtual void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs, - const DynamicPluginTensorDesc* out, int32_t nbOutputs) TRTNOEXCEPT = 0; + virtual void configurePlugin(const DynamicPluginTensorDesc *in, int32_t nbInputs, + const DynamicPluginTensorDesc *out, int32_t nbOutputs) TRTNOEXCEPT = 0; //! //! \brief Find the workspace size required by the layer. @@ -512,8 +500,8 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! //! \return The workspace size. //! - virtual size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, const PluginTensorDesc* outputs, - int32_t nbOutputs) const TRTNOEXCEPT = 0; + virtual size_t getWorkspaceSize(const PluginTensorDesc *inputs, int32_t nbInputs, const PluginTensorDesc *outputs, + int32_t nbOutputs) const TRTNOEXCEPT = 0; //! //! \brief Execute the layer. @@ -527,16 +515,16 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! //! \return 0 for success, else non-zero (which will cause engine termination). //! - virtual int32_t enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) TRTNOEXCEPT = 0; + virtual int32_t enqueue(const PluginTensorDesc *inputDesc, const PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) TRTNOEXCEPT = 0; protected: - int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE TRTNOEXCEPT - { + int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE TRTNOEXCEPT { return (static_cast(PluginVersion::kV2_DYNAMICEXT) << 24 | (NV_TENSORRT_VERSION & 0xFFFFFF)); } - virtual ~IPluginV2DynamicExt() {} + virtual ~IPluginV2DynamicExt() { + } // Rest of the methods below are obsolete inherited methods, and marked final when using a C++11 compiler. // Derived classes should not override them. @@ -550,8 +538,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! TRT_DEPRECATED Dims getOutputDimensions( - int32_t /*index*/, const Dims* /*inputs*/, int32_t /*nbInputDims*/) _TENSORRT_FINAL TRTNOEXCEPT - { + int32_t /*index*/, const Dims * /*inputs*/, int32_t /*nbInputDims*/) _TENSORRT_FINAL TRTNOEXCEPT { return Dims{-1, {}, {}}; } @@ -563,9 +550,8 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - bool isOutputBroadcastAcrossBatch(int32_t /*outputIndex*/, const bool* /*inputIsBroadcasted*/, - int32_t /*nbInputs*/) const _TENSORRT_FINAL TRTNOEXCEPT - { + bool isOutputBroadcastAcrossBatch(int32_t /*outputIndex*/, const bool * /*inputIsBroadcasted*/, + int32_t /*nbInputs*/) const _TENSORRT_FINAL TRTNOEXCEPT { return false; } @@ -577,8 +563,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - bool canBroadcastInputAcrossBatch(int32_t /*inputIndex*/) const _TENSORRT_FINAL TRTNOEXCEPT - { + bool canBroadcastInputAcrossBatch(int32_t /*inputIndex*/) const _TENSORRT_FINAL TRTNOEXCEPT { return true; } @@ -593,8 +578,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - bool supportsFormat(DataType /*type*/, PluginFormat /*format*/) const _TENSORRT_FINAL TRTNOEXCEPT - { + bool supportsFormat(DataType /*type*/, PluginFormat /*format*/) const _TENSORRT_FINAL TRTNOEXCEPT { return false; } @@ -610,11 +594,10 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - void configurePlugin(const Dims* /*inputDims*/, int32_t /*nbInputs*/, const Dims* /*outputDims*/, - int32_t /*nbOutputs*/, const DataType* /*inputTypes*/, const DataType* /*outputTypes*/, - const bool* /*inputIsBroadcast*/, const bool* /*outputIsBroadcast*/, PluginFormat /*floatFormat*/, - int32_t /*maxBatchSize*/) _TENSORRT_FINAL TRTNOEXCEPT - { + void configurePlugin(const Dims * /*inputDims*/, int32_t /*nbInputs*/, const Dims * /*outputDims*/, + int32_t /*nbOutputs*/, const DataType * /*inputTypes*/, const DataType * /*outputTypes*/, + const bool * /*inputIsBroadcast*/, const bool * /*outputIsBroadcast*/, PluginFormat /*floatFormat*/, + int32_t /*maxBatchSize*/) _TENSORRT_FINAL TRTNOEXCEPT { } //! @@ -629,8 +612,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - size_t getWorkspaceSize(int32_t /*maxBatchSize*/) const _TENSORRT_FINAL TRTNOEXCEPT - { + size_t getWorkspaceSize(int32_t /*maxBatchSize*/) const _TENSORRT_FINAL TRTNOEXCEPT { return 0; } @@ -645,9 +627,8 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - int32_t enqueue(int32_t /*batchSize*/, const void* const* /*inputs*/, void** /*outputs*/, void* /*workspace*/, - cudaStream_t /*stream*/) _TENSORRT_FINAL TRTNOEXCEPT - { + int32_t enqueue(int32_t /*batchSize*/, const void *const * /*inputs*/, void ** /*outputs*/, void * /*workspace*/, + cudaStream_t /*stream*/) _TENSORRT_FINAL TRTNOEXCEPT { return 1; } }; @@ -662,8 +643,7 @@ class IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext //! //! The profiler will only be called after execution is complete. It has a small impact on execution time. //! -class IProfiler -{ +class IProfiler { public: //! //! \brief Layer time reporting callback. @@ -671,9 +651,10 @@ class IProfiler //! \param layerName The name of the layer, set when constructing the network definition. //! \param ms The time in milliseconds to execute the layer. //! - virtual void reportLayerTime(const char* layerName, float ms) TRTNOEXCEPT = 0; + virtual void reportLayerTime(const char *layerName, float ms) TRTNOEXCEPT = 0; - virtual ~IProfiler() {} + virtual ~IProfiler() { + } }; //! @@ -682,19 +663,17 @@ class IProfiler //! //! The power weights of an IScaleLayer are omitted. Refitting those is not supported. //! -enum class WeightsRole : int32_t -{ - kKERNEL = 0, //!< kernel for IConvolutionLayer, IDeconvolutionLayer, or IFullyConnectedLayer - kBIAS = 1, //!< bias for IConvolutionLayer, IDeconvolutionLayer, or IFullyConnectedLayer - kSHIFT = 2, //!< shift part of IScaleLayer - kSCALE = 3, //!< scale part of IScaleLayer - kCONSTANT = 4, //!< weights for IConstantLayer +enum class WeightsRole : int32_t { + kKERNEL = 0, //!< kernel for IConvolutionLayer, IDeconvolutionLayer, or IFullyConnectedLayer + kBIAS = 1, //!< bias for IConvolutionLayer, IDeconvolutionLayer, or IFullyConnectedLayer + kSHIFT = 2, //!< shift part of IScaleLayer + kSCALE = 3, //!< scale part of IScaleLayer + kCONSTANT = 4, //!< weights for IConstantLayer }; //! Maximum number of elements in WeightsRole enum. \see WeightsRole -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 5; } @@ -703,16 +682,14 @@ constexpr inline int32_t EnumMax() //! \brief The device that this layer/network will execute on. //! //! -enum class DeviceType : int32_t -{ - kGPU, //!< GPU Device - kDLA, //!< DLA Core +enum class DeviceType : int32_t { + kGPU, //!< GPU Device + kDLA, //!< DLA Core }; //! Maximum number of elements in DeviceType enum. \see DeviceType -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 2; } @@ -723,8 +700,7 @@ constexpr inline int32_t EnumMax() //! //! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. //! -class IRuntime -{ +class IRuntime { public: //! //! \brief Deserialize an engine from a stream. @@ -735,7 +711,7 @@ class IRuntime //! //! \return The engine, or nullptr if it could not be deserialized. //! - virtual nvinfer1::ICudaEngine* deserializeCudaEngine(const void* blob, std::size_t size, IPluginFactory* pluginFactory) noexcept = 0; + virtual nvinfer1::ICudaEngine *deserializeCudaEngine(const void *blob, std::size_t size, IPluginFactory *pluginFactory) noexcept = 0; //! //! \brief Set the DLA core that the deserialized engine must execute on. @@ -765,7 +741,8 @@ class IRuntime virtual void destroy() noexcept = 0; protected: - virtual ~IRuntime() {} + virtual ~IRuntime() { + } public: //! @@ -776,7 +753,7 @@ class IRuntime //! //! If nullptr is passed, the default allocator will be used. //! - virtual void setGpuAllocator(IGpuAllocator* allocator) noexcept = 0; + virtual void setGpuAllocator(IGpuAllocator *allocator) noexcept = 0; //! //! \brief Set the ErrorRecorder for this interface @@ -790,7 +767,7 @@ class IRuntime // //! \see getErrorRecorder //! - virtual void setErrorRecorder(IErrorRecorder* recorder) noexcept = 0; + virtual void setErrorRecorder(IErrorRecorder *recorder) noexcept = 0; //! //! \brief get the ErrorRecorder assigned to this interface. @@ -802,7 +779,7 @@ class IRuntime //! //! \see setErrorRecorder //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; + virtual IErrorRecorder *getErrorRecorder() const noexcept = 0; //! //! \brief Deserialize an engine from a stream when plugin factory is not used. @@ -812,8 +789,7 @@ class IRuntime //! //! \return The engine, or nullptr if it could not be deserialized. //! - nvinfer1::ICudaEngine* deserializeCudaEngine(const void* blob, std::size_t size) noexcept - { + nvinfer1::ICudaEngine *deserializeCudaEngine(const void *blob, std::size_t size) noexcept { return deserializeCudaEngine(blob, size, nullptr); } }; @@ -825,8 +801,7 @@ class IRuntime //! //! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. //! -class IRefitter -{ +class IRefitter { public: //! //! \brief Specify new weights for a layer of given name. @@ -838,7 +813,7 @@ class IRefitter //! * The number of weights is inconsistent with the layer’s original specification. //! //! Modifying the weights before method refit() completes will result in undefined behavior. - virtual bool setWeights(const char* layerName, WeightsRole role, Weights weights) TRTNOEXCEPT = 0; + virtual bool setWeights(const char *layerName, WeightsRole role, Weights weights) TRTNOEXCEPT = 0; //! //! \brief Updates associated engine. Return true if successful. @@ -863,7 +838,7 @@ class IRefitter //! If layerNames!=nullptr, each written pointer points to a string owned by //! the engine being refitted, and becomes invalid when the engine is destroyed. //! - virtual int32_t getMissing(int32_t size, const char** layerNames, WeightsRole* roles) TRTNOEXCEPT = 0; + virtual int32_t getMissing(int32_t size, const char **layerNames, WeightsRole *roles) TRTNOEXCEPT = 0; //! //! \brief Get description of all weights that could be refit. @@ -877,12 +852,13 @@ class IRefitter //! If layerNames!=nullptr, each written pointer points to a string owned by //! the engine being refitted, and becomes invalid when the engine is destroyed. //! - virtual int32_t getAll(int32_t size, const char** layerNames, WeightsRole* roles) TRTNOEXCEPT = 0; + virtual int32_t getAll(int32_t size, const char **layerNames, WeightsRole *roles) TRTNOEXCEPT = 0; virtual void destroy() TRTNOEXCEPT = 0; protected: - virtual ~IRefitter() {} + virtual ~IRefitter() { + } public: //! @@ -897,7 +873,7 @@ class IRefitter //! Returns false if there is no Int8 engine tensor derived from //! a network tensor of that name. If successful, then getMissing //! may report that some weights need to be supplied. - virtual bool setDynamicRange(const char* tensorName, float min, float max) TRTNOEXCEPT = 0; + virtual bool setDynamicRange(const char *tensorName, float min, float max) TRTNOEXCEPT = 0; //! //! \brief Get minimum of dynamic range. @@ -906,7 +882,7 @@ class IRefitter //! //! If the dynamic range was never set, returns the minimum computed during calibration. //! - virtual float getDynamicRangeMin(const char* tensorName) const TRTNOEXCEPT = 0; + virtual float getDynamicRangeMin(const char *tensorName) const TRTNOEXCEPT = 0; //! //! \brief Get maximum of dynamic range. @@ -915,7 +891,7 @@ class IRefitter //! //! If the dynamic range was never set, returns the maximum computed during calibration. //! - virtual float getDynamicRangeMax(const char* tensorName) const TRTNOEXCEPT = 0; + virtual float getDynamicRangeMax(const char *tensorName) const TRTNOEXCEPT = 0; //! //! \brief Get names of all tensors that have refittable dynamic ranges. @@ -928,7 +904,7 @@ class IRefitter //! If tensorNames!=nullptr, each written pointer points to a string owned by //! the engine being refitted, and becomes invalid when the engine is destroyed. //! - virtual int32_t getTensorsWithDynamicRange(int32_t size, const char** tensorNames) const TRTNOEXCEPT = 0; + virtual int32_t getTensorsWithDynamicRange(int32_t size, const char **tensorNames) const TRTNOEXCEPT = 0; //! //! \brief Set the ErrorRecorder for this interface @@ -942,7 +918,7 @@ class IRefitter // //! \see getErrorRecorder //! - virtual void setErrorRecorder(IErrorRecorder* recorder) TRTNOEXCEPT = 0; + virtual void setErrorRecorder(IErrorRecorder *recorder) TRTNOEXCEPT = 0; //! //! \brief get the ErrorRecorder assigned to this interface. @@ -954,7 +930,7 @@ class IRefitter //! //! \see setErrorRecorder //! - virtual IErrorRecorder* getErrorRecorder() const TRTNOEXCEPT = 0; + virtual IErrorRecorder *getErrorRecorder() const TRTNOEXCEPT = 0; }; //! @@ -963,8 +939,7 @@ class IRefitter //! \brief Plugin factory for deserialization. //! //! This Interface is guaranteed not to change for the same major version of TensorRT. -class IPluginFactory -{ +class IPluginFactory { public: //! //! \brief Create a plugin from serialized data. @@ -980,9 +955,10 @@ class IPluginFactory //! //! \see IPlugin::serialize() //! - virtual IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) TRTNOEXCEPT = 0; + virtual IPlugin *createPlugin(const char *layerName, const void *serialData, size_t serialLength) TRTNOEXCEPT = 0; - virtual ~IPluginFactory() {} + virtual ~IPluginFactory() { + } }; //! @@ -995,17 +971,15 @@ class IPluginFactory //! //! \see IOptimizationProfile::setDimensions(), IOptimizationProfile::setShapeValues() //! -enum class OptProfileSelector : int32_t -{ - kMIN = 0, //!< This is used to set or get the minimum permitted value for dynamic dimensions etc. - kOPT = 1, //!< This is used to set or get the value that is used in the optimization (kernel selection). - kMAX = 2 //!< This is used to set or get the maximum permitted value for dynamic dimensions etc. +enum class OptProfileSelector : int32_t { + kMIN = 0, //!< This is used to set or get the minimum permitted value for dynamic dimensions etc. + kOPT = 1, //!< This is used to set or get the value that is used in the optimization (kernel selection). + kMAX = 2 //!< This is used to set or get the maximum permitted value for dynamic dimensions etc. }; //!< Number of different values of OptProfileSelector enum. \see OptProfileSelector -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 3; } @@ -1031,8 +1005,7 @@ constexpr inline int32_t EnumMax() //! //! \see IBuilderConfig::addOptimizationProfile() //! -class IOptimizationProfile -{ +class IOptimizationProfile { public: //! //! \brief Set the minimum / optimum / maximum dimensions for a dynamic input tensor. @@ -1059,14 +1032,14 @@ class IOptimizationProfile //! //! \warning If run on DLA, minimum, optimum, and maximum dimensions must to be the same. //! - virtual bool setDimensions(const char* inputName, OptProfileSelector select, Dims dims) noexcept = 0; + virtual bool setDimensions(const char *inputName, OptProfileSelector select, Dims dims) noexcept = 0; //! //! \brief Get the minimum / optimum / maximum dimensions for a dynamic input tensor. //! //! If the dimensions have not been previously set via setDimensions(), return an invalid Dims with nbDims == -1. //! - virtual Dims getDimensions(const char* inputName, OptProfileSelector select) const noexcept = 0; + virtual Dims getDimensions(const char *inputName, OptProfileSelector select) const noexcept = 0; //! //! \brief Set the minimum / optimum / maximum values for an input shape tensor. @@ -1092,8 +1065,7 @@ class IOptimizationProfile //! \warning If run on DLA, minimum, optimum, and maximum shape values must to be the same. //! virtual bool setShapeValues( - const char* inputName, OptProfileSelector select, const int32_t* values, int32_t nbValues) noexcept - = 0; + const char *inputName, OptProfileSelector select, const int32_t *values, int32_t nbValues) noexcept = 0; //! //! \brief Get the number of values for an input shape tensor. @@ -1101,14 +1073,14 @@ class IOptimizationProfile //! This will return the number of shape values if setShapeValues() has been called before for this input tensor. //! Otherwise, return -1. //! - virtual int32_t getNbShapeValues(const char* inputName) const noexcept = 0; + virtual int32_t getNbShapeValues(const char *inputName) const noexcept = 0; //! //! \brief Get the minimum / optimum / maximum values for an input shape tensor. //! //! If the shape values have not been set previously with setShapeValues(), this returns nullptr. //! - virtual const int32_t* getShapeValues(const char* inputName, OptProfileSelector select) const noexcept = 0; + virtual const int32_t *getShapeValues(const char *inputName, OptProfileSelector select) const noexcept = 0; //! //! \brief Set a target for extra GPU memory that may be used by this profile. @@ -1154,8 +1126,7 @@ class IOptimizationProfile //! //! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. //! -class ICudaEngine -{ +class ICudaEngine { public: //! //! \brief Get the number of binding indices. @@ -1186,7 +1157,7 @@ class ICudaEngine //! //! \see getNbBindings() getBindingName() //! - virtual int32_t getBindingIndex(const char* name) const noexcept = 0; + virtual int32_t getBindingIndex(const char *name) const noexcept = 0; //! //! \brief Retrieve the name corresponding to a binding index. @@ -1203,7 +1174,7 @@ class ICudaEngine //! //! \see getBindingIndex() //! - virtual const char* getBindingName(int32_t bindingIndex) const noexcept = 0; + virtual const char *getBindingName(int32_t bindingIndex) const noexcept = 0; //! //! \brief Determine whether a binding is an input binding. @@ -1290,7 +1261,7 @@ class ICudaEngine //! //! \see IRuntime::deserializeCudaEngine() safe::IRuntime::deserializeCudaEngine() //! - virtual IHostMemory* serialize() const noexcept = 0; + virtual IHostMemory *serialize() const noexcept = 0; //! //! \brief Create an execution context. @@ -1302,7 +1273,7 @@ class ICudaEngine //! \see IExecutionContext. //! \see IExecutionContext::setOptimizationProfile() //! - virtual IExecutionContext* createExecutionContext() noexcept = 0; + virtual IExecutionContext *createExecutionContext() noexcept = 0; //! //! \brief Destroy this object; @@ -1322,7 +1293,8 @@ class ICudaEngine virtual TensorLocation getLocation(int32_t bindingIndex) const noexcept = 0; protected: - virtual ~ICudaEngine() {} + virtual ~ICudaEngine() { + } public: //! \brief create an execution context without any device memory allocated @@ -1331,7 +1303,7 @@ class ICudaEngine //! //! \see getDeviceMemorySize() IExecutionContext::setDeviceMemory() //! - virtual IExecutionContext* createExecutionContextWithoutDeviceMemory() noexcept = 0; + virtual IExecutionContext *createExecutionContextWithoutDeviceMemory() noexcept = 0; //! //! \brief Return the amount of device memory required by an execution context. @@ -1390,7 +1362,7 @@ class ICudaEngine //! //! \param bindingIndex The binding Index. //! - virtual const char* getBindingFormatDesc(int32_t bindingIndex) const noexcept = 0; + virtual const char *getBindingFormatDesc(int32_t bindingIndex) const noexcept = 0; //! //! \brief Return the dimension index that the buffer is vectorized. @@ -1411,7 +1383,7 @@ class ICudaEngine //! //! \return A zero delimited C-style string representing the name of the network. //! - virtual const char* getName() const noexcept = 0; + virtual const char *getName() const noexcept = 0; //! //! \brief Get the number of optimization profiles defined for this engine. @@ -1444,8 +1416,7 @@ class ICudaEngine //! Otherwise the bindingIndex is considered invalid. //! virtual Dims getProfileDimensions(int32_t bindingIndex, int32_t profileIndex, OptProfileSelector select) const - noexcept - = 0; + noexcept = 0; //! //! \brief Get minimum / optimum / maximum values for an input shape binding under an optimization profile. @@ -1468,9 +1439,8 @@ class ICudaEngine //! //! \see ICudaEngine::getProfileDimensions //! - virtual const int32_t* getProfileShapeValues( - int32_t profileIndex, int32_t inputIndex, OptProfileSelector select) const noexcept - = 0; + virtual const int32_t *getProfileShapeValues( + int32_t profileIndex, int32_t inputIndex, OptProfileSelector select) const noexcept = 0; //! //! \brief True if tensor is required as input for shape calculations or output from them. @@ -1539,7 +1509,7 @@ class ICudaEngine // //! \see getErrorRecorder //! - virtual void setErrorRecorder(IErrorRecorder* recorder) noexcept = 0; + virtual void setErrorRecorder(IErrorRecorder *recorder) noexcept = 0; //! //! \brief get the ErrorRecorder assigned to this interface. @@ -1551,7 +1521,7 @@ class ICudaEngine //! //! \see setErrorRecorder //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; + virtual IErrorRecorder *getErrorRecorder() const noexcept = 0; //! //! \brief Query whether the engine was built with an implicit batch dimension. @@ -1580,8 +1550,7 @@ class ICudaEngine //! dynamic shapes, each execution context in concurrent use must use a separate optimization profile. //! //! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -class IExecutionContext -{ +class IExecutionContext { public: //! //! \brief Synchronously execute inference on a batch. @@ -1594,7 +1563,7 @@ class IExecutionContext //! //! \see ICudaEngine::getBindingIndex() ICudaEngine::getMaxBatchSize() //! - virtual bool execute(int32_t batchSize, void** bindings) noexcept = 0; + virtual bool execute(int32_t batchSize, void **bindings) noexcept = 0; //! //! \brief Asynchronously execute inference on a batch. @@ -1610,8 +1579,7 @@ class IExecutionContext //! //! \see ICudaEngine::getBindingIndex() ICudaEngine::getMaxBatchSize() //! - virtual bool enqueue(int32_t batchSize, void** bindings, cudaStream_t stream, cudaEvent_t* inputConsumed) noexcept - = 0; + virtual bool enqueue(int32_t batchSize, void **bindings, cudaStream_t stream, cudaEvent_t *inputConsumed) noexcept = 0; //! //! \brief Set the debug sync flag. @@ -1635,21 +1603,21 @@ class IExecutionContext //! //! \see IProfiler getProfiler() //! - virtual void setProfiler(IProfiler*) noexcept = 0; + virtual void setProfiler(IProfiler *) noexcept = 0; //! //! \brief Get the profiler. //! //! \see IProfiler setProfiler() //! - virtual IProfiler* getProfiler() const noexcept = 0; + virtual IProfiler *getProfiler() const noexcept = 0; //! //! \brief Get the associated engine. //! //! \see ICudaEngine //! - virtual const ICudaEngine& getEngine() const noexcept = 0; + virtual const ICudaEngine &getEngine() const noexcept = 0; //! //! \brief Destroy this object. @@ -1657,7 +1625,8 @@ class IExecutionContext virtual void destroy() noexcept = 0; protected: - virtual ~IExecutionContext() noexcept {} + virtual ~IExecutionContext() noexcept { + } public: //! @@ -1667,14 +1636,14 @@ class IExecutionContext //! //! \see getName() //! - virtual void setName(const char* name) noexcept = 0; + virtual void setName(const char *name) noexcept = 0; //! //! \brief Return the name of the execution context. //! //! \see setName() //! - virtual const char* getName() const noexcept = 0; + virtual const char *getName() const noexcept = 0; //! //! \brief Set the device memory for use by this execution context. @@ -1687,7 +1656,7 @@ class IExecutionContext //! //! \see ICudaEngine::getDeviceMemorySize() ICudaEngine::createExecutionContextWithoutDeviceMemory() //! - virtual void setDeviceMemory(void* memory) noexcept = 0; + virtual void setDeviceMemory(void *memory) noexcept = 0; //! //! \brief Return the strides of the buffer for the given binding. @@ -1817,7 +1786,7 @@ class IExecutionContext //! This method will fail unless a valid optimization profile is defined for the current //! execution context (getOptimizationProfile() must not be -1). //! - virtual bool setInputShapeBinding(int32_t bindingIndex, const int32_t* data) noexcept = 0; + virtual bool setInputShapeBinding(int32_t bindingIndex, const int32_t *data) noexcept = 0; //! //! \brief Get values of an input tensor required for shape calculations or an output tensor produced by shape @@ -1836,7 +1805,7 @@ class IExecutionContext //! //! \see isShapeBinding(bindingIndex) //! - virtual bool getShapeBinding(int32_t bindingIndex, int32_t* data) const noexcept = 0; + virtual bool getShapeBinding(int32_t bindingIndex, int32_t *data) const noexcept = 0; //! //! \brief Whether all dynamic dimensions of input tensors have been specified @@ -1873,7 +1842,7 @@ class IExecutionContext // //! \see getErrorRecorder //! - virtual void setErrorRecorder(IErrorRecorder* recorder) noexcept = 0; + virtual void setErrorRecorder(IErrorRecorder *recorder) noexcept = 0; //! //! \brief get the ErrorRecorder assigned to this interface. @@ -1885,7 +1854,7 @@ class IExecutionContext //! //! \see setErrorRecorder //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; + virtual IErrorRecorder *getErrorRecorder() const noexcept = 0; //! //! \brief Synchronously execute inference a network. @@ -1899,7 +1868,7 @@ class IExecutionContext //! //! \see ICudaEngine::getBindingIndex() ICudaEngine::getMaxBatchSize() //! - virtual bool executeV2(void** bindings) noexcept = 0; + virtual bool executeV2(void **bindings) noexcept = 0; //! //! \brief Asynchronously execute inference. @@ -1920,7 +1889,7 @@ class IExecutionContext //! used, the first enqueueV2() call after a setInputShapeBinding() call will cause failure in stream capture //! due to resource allocation. Please call enqueueV2() once before capturing the graph. //! - virtual bool enqueueV2(void** bindings, cudaStream_t stream, cudaEvent_t* inputConsumed) noexcept = 0; + virtual bool enqueueV2(void **bindings, cudaStream_t stream, cudaEvent_t *inputConsumed) noexcept = 0; //! //! \brief Select an optimization profile for the current context with async @@ -1963,8 +1932,8 @@ class IExecutionContext //! \see ICudaEngine::getNbOptimizationProfiles() //! IExecutionContext::setOptimizationProfile() virtual bool setOptimizationProfileAsync(int32_t profileIndex, cudaStream_t stream) noexcept = 0; -}; // class IExecutionContext -} // namespace nvinfer1 +}; // class IExecutionContext +} // namespace nvinfer1 //! //! Internal C entry point for creating IRuntime. @@ -1978,18 +1947,16 @@ class IExecutionContext //! // extern "C" TENSORRTAPI void* createInferRefitter_INTERNAL(void* engine, void* logger, int32_t version); -namespace nvinfer1 -{ -namespace // unnamed namespace avoids linkage surprises when linking objects built with different versions of this header. +namespace nvinfer1 { +namespace // unnamed namespace avoids linkage surprises when linking objects built with different versions of this header. { //! //! \brief Create an instance of an IRuntime class. //! //! This class is the logging class for the runtime. //! -inline IRuntime* createInferRuntime(ILogger& logger) -{ - return static_cast(createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); +inline IRuntime *createInferRuntime(ILogger &logger) { + return static_cast(createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); } //! @@ -1997,11 +1964,10 @@ inline IRuntime* createInferRuntime(ILogger& logger) //! //! This class is the logging class for the refitter. //! -inline IRefitter* createInferRefitter(ICudaEngine& engine, ILogger& logger) -{ - return static_cast(createInferRefitter_INTERNAL(&engine, &logger, NV_TENSORRT_VERSION)); -} -} +inline IRefitter *createInferRefitter(ICudaEngine &engine, ILogger &logger) { + return static_cast(createInferRefitter_INTERNAL(&engine, &logger, NV_TENSORRT_VERSION)); } +} // namespace +} // namespace nvinfer1 -#endif // NV_INFER_RUNTIME_H +#endif // NV_INFER_RUNTIME_H diff --git a/src/bb/dnn/NvInferRuntimeCommon.h b/src/bb/dnn/NvInferRuntimeCommon.h index e6c6dd00..8882527a 100644 --- a/src/bb/dnn/NvInferRuntimeCommon.h +++ b/src/bb/dnn/NvInferRuntimeCommon.h @@ -79,25 +79,23 @@ struct cublasContext; struct cudnnContext; -typedef struct CUstream_st* cudaStream_t; //!< Forward declaration of cudaStream_t. -typedef struct CUevent_st* cudaEvent_t; //!< Forward declaration of cudaEvent_t. +typedef struct CUstream_st *cudaStream_t; //!< Forward declaration of cudaStream_t. +typedef struct CUevent_st *cudaEvent_t; //!< Forward declaration of cudaEvent_t. -static const int32_t NV_TENSORRT_VERSION - = (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH; // major, minor, patch +static const int32_t NV_TENSORRT_VERSION = (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH; // major, minor, patch //! //! \namespace nvinfer1 //! //! \brief The TensorRT API version 1 namespace. //! -namespace nvinfer1 -{ +namespace nvinfer1 { -class IErrorRecorder; //!< Forward declare IErrorRecorder for use in other interfaces. -class IGpuAllocator; //!< Forward declare IGpuAllocator for use in other interfaces. +class IErrorRecorder; //!< Forward declare IErrorRecorder for use in other interfaces. +class IGpuAllocator; //!< Forward declare IGpuAllocator for use in other interfaces. //! Maximum number of elements in an enumeration type. -template +template constexpr inline int32_t EnumMax(); //! @@ -105,26 +103,24 @@ constexpr inline int32_t EnumMax(); //! //! \brief Enumerates the types of activation to perform in an activation layer. //! -enum class ActivationType : int32_t -{ - kRELU = 0, //!< Rectified linear activation. - kSIGMOID = 1, //!< Sigmoid activation. - kTANH = 2, //!< TanH activation. - kLEAKY_RELU = 3, //!< LeakyRelu activation: x>=0 ? x : alpha * x. - kELU = 4, //!< Elu activation: x>=0 ? x : alpha * (exp(x) - 1). - kSELU = 5, //!< Selu activation: x>0 ? beta * x : beta * (alpha*exp(x) - alpha) - kSOFTSIGN = 6, //!< Softsign activation: x / (1+|x|) - kSOFTPLUS = 7, //!< Parametric softplus activation: alpha*log(exp(beta*x)+1) - kCLIP = 8, //!< Clip activation: max(alpha, min(beta, x)) - kHARD_SIGMOID = 9, //!< Hard sigmoid activation: max(0, min(1, alpha*x+beta)) - kSCALED_TANH = 10, //!< Scaled tanh activation: alpha*tanh(beta*x) - kTHRESHOLDED_RELU = 11 //!< Thresholded ReLU activation: x>alpha ? x : 0 +enum class ActivationType : int32_t { + kRELU = 0, //!< Rectified linear activation. + kSIGMOID = 1, //!< Sigmoid activation. + kTANH = 2, //!< TanH activation. + kLEAKY_RELU = 3, //!< LeakyRelu activation: x>=0 ? x : alpha * x. + kELU = 4, //!< Elu activation: x>=0 ? x : alpha * (exp(x) - 1). + kSELU = 5, //!< Selu activation: x>0 ? beta * x : beta * (alpha*exp(x) - alpha) + kSOFTSIGN = 6, //!< Softsign activation: x / (1+|x|) + kSOFTPLUS = 7, //!< Parametric softplus activation: alpha*log(exp(beta*x)+1) + kCLIP = 8, //!< Clip activation: max(alpha, min(beta, x)) + kHARD_SIGMOID = 9, //!< Hard sigmoid activation: max(0, min(1, alpha*x+beta)) + kSCALED_TANH = 10, //!< Scaled tanh activation: alpha*tanh(beta*x) + kTHRESHOLDED_RELU = 11 //!< Thresholded ReLU activation: x>alpha ? x : 0 }; //! Maximum number of elements in ActivationType enum. \see ActivationType -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 12; } @@ -133,8 +129,7 @@ constexpr inline int32_t EnumMax() //! //! \brief The type of weights and tensors. //! -enum class DataType : int32_t -{ +enum class DataType : int32_t { //! 32-bit floating point format. kFLOAT = 0, @@ -152,9 +147,8 @@ enum class DataType : int32_t }; //! Maximum number of elements in DataType enum. \see DataType -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 5; } @@ -162,18 +156,16 @@ constexpr inline int32_t EnumMax() //! \enum DimensionType //! \brief The type of data encoded across this dimension. //! -enum class DimensionType : int32_t -{ - kSPATIAL = 0, //!< Elements correspond to different spatial data. - kCHANNEL = 1, //!< Elements correspond to different channels. - kINDEX = 2, //!< Elements correspond to different batch index. - kSEQUENCE = 3 //!< Elements correspond to different sequence values. +enum class DimensionType : int32_t { + kSPATIAL = 0, //!< Elements correspond to different spatial data. + kCHANNEL = 1, //!< Elements correspond to different channels. + kINDEX = 2, //!< Elements correspond to different batch index. + kSEQUENCE = 3 //!< Elements correspond to different sequence values. }; //! Maximum number of elements in DimensionType enum. \see DimensionType -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 4; } @@ -191,14 +183,13 @@ constexpr inline int32_t EnumMax() //! TensorRT can also return an "unknown rank" dims structure. This structure is represented by nbDims == -1 //! and d[i] == -1 for all d. //! -class Dims -{ +class Dims { public: - static const int32_t MAX_DIMS = 8; //!< The maximum number of dimensions supported for a tensor. - int32_t nbDims; //!< The number of dimensions. - int32_t d[MAX_DIMS]; //!< The extent of each dimension. - TRT_DEPRECATED DimensionType type[MAX_DIMS]; //!< The type of each dimension, provided for backwards compatibility - //!< and will be removed in TensorRT 8.0. + static const int32_t MAX_DIMS = 8; //!< The maximum number of dimensions supported for a tensor. + int32_t nbDims; //!< The number of dimensions. + int32_t d[MAX_DIMS]; //!< The extent of each dimension. + TRT_DEPRECATED DimensionType type[MAX_DIMS]; //!< The type of each dimension, provided for backwards compatibility + //!< and will be removed in TensorRT 8.0. }; //! @@ -222,15 +213,14 @@ typedef uint32_t TensorFormats; //! For more information about data formats, see the topic "Data Format Description" located in the //! TensorRT Developer Guide (https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html). //! -enum class TensorFormat : int32_t -{ +enum class TensorFormat : int32_t { //! Row major linear format. //! For a tensor with dimensions {N, C, H, W} or {numbers, channels, //! columns, rows}, the dimensional index corresponds to {3, 2, 1, 0} //! and thus the order is W minor. kLINEAR = 0, - kNCHW TRT_DEPRECATED_ENUM = kLINEAR, //!< Deprecated name of kLINEAR, provided for backwards compatibility and will - //!< be removed in TensorRT 8.0. + kNCHW TRT_DEPRECATED_ENUM = kLINEAR, //!< Deprecated name of kLINEAR, provided for backwards compatibility and will + //!< be removed in TensorRT 8.0. //! Two wide channel vectorized row major format. This format is bound to //! FP16. It is only available for dimensions >= 3. @@ -239,8 +229,8 @@ enum class TensorFormat : int32_t //! [N][(C+1)/2][H][W][2], with the tensor coordinates (n, c, h, w) //! mapping to array subscript [n][c/2][h][w][c%2]. kCHW2 = 1, - kNC2HW2 TRT_DEPRECATED_ENUM = kCHW2, //!< Deprecated name of kCHW2, provided for backwards compatibility and will - //!< be removed in TensorRT 8.0. + kNC2HW2 TRT_DEPRECATED_ENUM = kCHW2, //!< Deprecated name of kCHW2, provided for backwards compatibility and will + //!< be removed in TensorRT 8.0. //! Eight channel format where C is padded to a multiple of 8. This format //! is bound to FP16. It is only available for dimensions >= 3. @@ -249,8 +239,8 @@ enum class TensorFormat : int32_t //! [N][H][W][(C+7)/8*8], with the tensor coordinates (n, h, w, c) //! mapping to array subscript [n][h][w][c]. kHWC8 = 2, - kNHWC8 TRT_DEPRECATED_ENUM = kHWC8, //!< Deprecated name of kHWC8, provided for backwards compatibility and will - //!< be removed in TensorRT 8.0. + kNHWC8 TRT_DEPRECATED_ENUM = kHWC8, //!< Deprecated name of kHWC8, provided for backwards compatibility and will + //!< be removed in TensorRT 8.0. //! Four wide channel vectorized row major format. This format is bound to //! INT8 or FP16. It is only available for dimensions >= 3. @@ -317,9 +307,8 @@ enum class TensorFormat : int32_t using PluginFormat = TensorFormat; //! Maximum number of elements in TensorFormat enum. \see TensorFormat -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 9; } @@ -333,10 +322,9 @@ constexpr inline int32_t EnumMax() //! \see IPluginV2IOExt::supportsFormat //! \see IPluginV2IOExt::configurePlugin //! -struct PluginTensorDesc -{ +struct PluginTensorDesc { Dims dims; - DataType type; //!< \warning DataType:kBOOL not supported. + DataType type; //!< \warning DataType:kBOOL not supported. TensorFormat format; float scale; }; @@ -347,12 +335,11 @@ struct PluginTensorDesc //! //! Tag for plug-in versions. Used in upper byte of getTensorRTVersion(). //! -enum class PluginVersion : uint8_t -{ - kV2 = 0, //! IPluginV2 - kV2_EXT = 1, //! IPluginV2Ext - kV2_IOEXT = 2, //! IPluginV2IOExt - kV2_DYNAMICEXT = 3, //! IPluginV2DynamicExt +enum class PluginVersion : uint8_t { + kV2 = 0, //! IPluginV2 + kV2_EXT = 1, //! IPluginV2Ext + kV2_IOEXT = 2, //! IPluginV2IOExt + kV2_DYNAMICEXT = 3, //! IPluginV2DynamicExt }; //! \class IPluginV2 @@ -366,8 +353,7 @@ enum class PluginVersion : uint8_t //! \see IPluginCreator //! \see IPluginRegistry //! -class IPluginV2 -{ +class IPluginV2 { public: //! //! \brief Return the API version with which this plugin was built. @@ -375,8 +361,7 @@ class IPluginV2 //! Do not override this method as it is used by the TensorRT library to maintain backwards-compatibility with //! plugins. //! - virtual int32_t getTensorRTVersion() const TRTNOEXCEPT - { + virtual int32_t getTensorRTVersion() const TRTNOEXCEPT { return NV_TENSORRT_VERSION; } @@ -384,13 +369,13 @@ class IPluginV2 //! \brief Return the plugin type. Should match the plugin name returned by the corresponding plugin creator // \see IPluginCreator::getPluginName() //! - virtual const char* getPluginType() const TRTNOEXCEPT = 0; + virtual const char *getPluginType() const TRTNOEXCEPT = 0; //! //! \brief Return the plugin version. Should match the plugin version returned by the corresponding plugin creator // \see IPluginCreator::getPluginVersion() //! - virtual const char* getPluginVersion() const TRTNOEXCEPT = 0; + virtual const char *getPluginVersion() const TRTNOEXCEPT = 0; //! //! \brief Get the number of outputs from the layer. @@ -412,7 +397,7 @@ class IPluginV2 //! This function is called by the implementations of INetworkDefinition and IBuilder. In particular, it is called //! prior to any call to initialize(). //! - virtual Dims getOutputDimensions(int32_t index, const Dims* inputs, int32_t nbInputDims) TRTNOEXCEPT = 0; + virtual Dims getOutputDimensions(int32_t index, const Dims *inputs, int32_t nbInputDims) TRTNOEXCEPT = 0; //! //! \brief Check format support. @@ -456,8 +441,8 @@ class IPluginV2 //! //! \warning DataType:kBOOL not supported. //! - virtual void configureWithFormat(const Dims* inputDims, int32_t nbInputs, const Dims* outputDims, int32_t nbOutputs, - DataType type, PluginFormat format, int32_t maxBatchSize) TRTNOEXCEPT = 0; + virtual void configureWithFormat(const Dims *inputDims, int32_t nbInputs, const Dims *outputDims, int32_t nbOutputs, + DataType type, PluginFormat format, int32_t maxBatchSize) TRTNOEXCEPT = 0; //! //! \brief Initialize the layer for execution. This is called when the engine is created. @@ -493,8 +478,8 @@ class IPluginV2 //! //! \return 0 for success, else non-zero (which will cause engine termination). //! - virtual int32_t enqueue(int32_t batchSize, const void* const* inputs, void** outputs, void* workspace, - cudaStream_t stream) TRTNOEXCEPT = 0; + virtual int32_t enqueue(int32_t batchSize, const void *const *inputs, void **outputs, void *workspace, + cudaStream_t stream) TRTNOEXCEPT = 0; //! //! \brief Find the size of the serialization buffer required. @@ -510,7 +495,7 @@ class IPluginV2 //! //! \see getSerializationSize() //! - virtual void serialize(void* buffer) const TRTNOEXCEPT = 0; + virtual void serialize(void *buffer) const TRTNOEXCEPT = 0; //! //! \brief Destroy the plugin object. This will be called when the network, builder or engine is destroyed. @@ -520,21 +505,22 @@ class IPluginV2 //! //! \brief Clone the plugin object. This copies over internal plugin parameters and returns a new plugin object with these parameters. //! - virtual IPluginV2* clone() const TRTNOEXCEPT = 0; + virtual IPluginV2 *clone() const TRTNOEXCEPT = 0; //! //! \brief Set the namespace that this plugin object belongs to. Ideally, all plugin //! objects from the same plugin library should have the same namespace. //! - virtual void setPluginNamespace(const char* pluginNamespace) TRTNOEXCEPT = 0; + virtual void setPluginNamespace(const char *pluginNamespace) TRTNOEXCEPT = 0; //! //! \brief Return the namespace of the plugin object. //! - virtual const char* getPluginNamespace() const TRTNOEXCEPT = 0; + virtual const char *getPluginNamespace() const TRTNOEXCEPT = 0; protected: - virtual ~IPluginV2() {} + virtual ~IPluginV2() { + } }; //! \class IPluginV2Ext @@ -547,8 +533,7 @@ class IPluginV2 //! //! \see IPluginV2 //! -class IPluginV2Ext : public IPluginV2 -{ +class IPluginV2Ext : public IPluginV2 { public: //! //! \brief Return the DataType of the plugin output at the requested index. @@ -558,7 +543,7 @@ class IPluginV2Ext : public IPluginV2 //! \warning DataType:kBOOL not supported. //! virtual nvinfer1::DataType getOutputDataType( - int32_t index, const nvinfer1::DataType* inputTypes, int32_t nbInputs) const TRTNOEXCEPT = 0; + int32_t index, const nvinfer1::DataType *inputTypes, int32_t nbInputs) const TRTNOEXCEPT = 0; //! \brief Return true if output tensor is broadcast across a batch. //! @@ -571,7 +556,7 @@ class IPluginV2Ext : public IPluginV2 //! physical replication of the values. //! virtual bool isOutputBroadcastAcrossBatch( - int32_t outputIndex, const bool* inputIsBroadcasted, int32_t nbInputs) const TRTNOEXCEPT = 0; + int32_t outputIndex, const bool *inputIsBroadcasted, int32_t nbInputs) const TRTNOEXCEPT = 0; //! \brief Return true if plugin can use input that is broadcast across batch without replication. //! @@ -616,11 +601,12 @@ class IPluginV2Ext : public IPluginV2 //! PluginV2IOExt or PluginV2DynamicExt for other PluginFormats. //! - virtual void configurePlugin(const Dims* inputDims, int32_t nbInputs, const Dims* outputDims, int32_t nbOutputs, - const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, - const bool* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) TRTNOEXCEPT = 0; + virtual void configurePlugin(const Dims *inputDims, int32_t nbInputs, const Dims *outputDims, int32_t nbOutputs, + const DataType *inputTypes, const DataType *outputTypes, const bool *inputIsBroadcast, + const bool *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) TRTNOEXCEPT = 0; - virtual ~IPluginV2Ext() {} + virtual ~IPluginV2Ext() { + } //! //! \brief Attach the plugin object to an execution context and grant the plugin the access to some context resource. @@ -633,7 +619,8 @@ class IPluginV2Ext : public IPluginV2 //! If the plugin needs per-context resource, it can be allocated here. //! The plugin can also get context-owned CUDNN and CUBLAS context here. //! - virtual void attachToContext(cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) TRTNOEXCEPT {} + virtual void attachToContext(cudnnContext * /*cudnn*/, cublasContext * /*cublas*/, IGpuAllocator * /*allocator*/) TRTNOEXCEPT { + } //! //! \brief Detach the plugin object from its execution context. @@ -641,14 +628,15 @@ class IPluginV2Ext : public IPluginV2 //! This function is called automatically for each plugin when a execution context is destroyed. //! If the plugin owns per-context resource, it can be released here. //! - virtual void detachFromContext() TRTNOEXCEPT {} + virtual void detachFromContext() TRTNOEXCEPT { + } //! //! \brief Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin object with these parameters. //! If the source plugin is pre-configured with configurePlugin(), the returned object should also be pre-configured. The returned object should allow attachToContext() with a new execution context. //! Cloned plugin objects can share the same per-engine immutable resource (e.g. weights) with the source object (e.g. via ref-counting) to avoid duplication. //! - virtual IPluginV2Ext* clone() const _TENSORRT_OVERRIDE TRTNOEXCEPT = 0; + virtual IPluginV2Ext *clone() const _TENSORRT_OVERRIDE TRTNOEXCEPT = 0; protected: //! @@ -658,18 +646,16 @@ class IPluginV2Ext : public IPluginV2 //! Do not override this method as it is used by the TensorRT library to maintain backwards-compatibility with //! plugins. //! - int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE TRTNOEXCEPT - { + int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE TRTNOEXCEPT { return (static_cast(PluginVersion::kV2_EXT) << 24 | (NV_TENSORRT_VERSION & 0xFFFFFF)); } //! //! \brief Derived classes should not implement this. In a C++11 API it would be override final. //! - void configureWithFormat(const Dims* /*inputDims*/, int32_t /*nbInputs*/, const Dims* /*outputDims*/, - int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, - int32_t /*maxBatchSize*/) _TENSORRT_OVERRIDE TRTNOEXCEPT - { + void configureWithFormat(const Dims * /*inputDims*/, int32_t /*nbInputs*/, const Dims * /*outputDims*/, + int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, + int32_t /*maxBatchSize*/) _TENSORRT_OVERRIDE TRTNOEXCEPT { } }; @@ -682,8 +668,7 @@ class IPluginV2Ext : public IPluginV2 //! //! \see IPluginV2Ext //! -class IPluginV2IOExt : public IPluginV2Ext -{ +class IPluginV2IOExt : public IPluginV2Ext { public: //! //! \brief Configure the layer. @@ -697,7 +682,7 @@ class IPluginV2IOExt : public IPluginV2Ext //! \param nbOutput Number of output tensors. //! virtual void configurePlugin( - const PluginTensorDesc* in, int32_t nbInput, const PluginTensorDesc* out, int32_t nbOutput) TRTNOEXCEPT = 0; + const PluginTensorDesc *in, int32_t nbInput, const PluginTensorDesc *out, int32_t nbOutput) TRTNOEXCEPT = 0; //! //! \brief Return true if plugin supports the format and datatype for the input/output indexed by pos. @@ -732,7 +717,7 @@ class IPluginV2IOExt : public IPluginV2Ext //! Warning: TensorRT will stop asking for formats once it finds kFORMAT_COMBINATION_LIMIT on combinations. //! virtual bool supportsFormatCombination( - int32_t pos, const PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) const TRTNOEXCEPT = 0; + int32_t pos, const PluginTensorDesc *inOut, int32_t nbInputs, int32_t nbOutputs) const TRTNOEXCEPT = 0; protected: //! @@ -745,8 +730,7 @@ class IPluginV2IOExt : public IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE - { + int32_t getTensorRTVersion() const _TENSORRT_OVERRIDE { return (static_cast(PluginVersion::kV2_IOEXT) << 24 | (NV_TENSORRT_VERSION & 0xFFFFFF)); } @@ -758,8 +742,7 @@ class IPluginV2IOExt : public IPluginV2Ext //! TRT_DEPRECATED void configureWithFormat( - const Dims*, int32_t, const Dims*, int32_t, DataType, PluginFormat, int32_t) _TENSORRT_OVERRIDE _TENSORRT_FINAL - { + const Dims *, int32_t, const Dims *, int32_t, DataType, PluginFormat, int32_t) _TENSORRT_OVERRIDE _TENSORRT_FINAL { } //! @@ -769,9 +752,8 @@ class IPluginV2IOExt : public IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - void configurePlugin(const Dims*, int32_t, const Dims*, int32_t, const DataType*, const DataType*, const bool*, - const bool*, PluginFormat, int32_t) _TENSORRT_OVERRIDE _TENSORRT_FINAL - { + void configurePlugin(const Dims *, int32_t, const Dims *, int32_t, const DataType *, const DataType *, const bool *, + const bool *, PluginFormat, int32_t) _TENSORRT_OVERRIDE _TENSORRT_FINAL { } //! @@ -781,8 +763,7 @@ class IPluginV2IOExt : public IPluginV2Ext //! \deprecated Deprecated interface will be removed in TensorRT 8.0. //! TRT_DEPRECATED - bool supportsFormat(DataType, PluginFormat) const _TENSORRT_OVERRIDE _TENSORRT_FINAL - { + bool supportsFormat(DataType, PluginFormat) const _TENSORRT_OVERRIDE _TENSORRT_FINAL { return false; } }; @@ -792,16 +773,15 @@ class IPluginV2IOExt : public IPluginV2Ext //! \brief The possible field types for custom layer. //! -enum class PluginFieldType : int32_t -{ - kFLOAT16 = 0, //!< FP16 field type. - kFLOAT32 = 1, //!< FP32 field type. - kFLOAT64 = 2, //!< FP64 field type. - kINT8 = 3, //!< INT8 field type. - kINT16 = 4, //!< INT16 field type. - kINT32 = 5, //!< INT32 field type. - kCHAR = 6, //!< char field type. - kDIMS = 7, //!< nvinfer1::Dims field type. +enum class PluginFieldType : int32_t { + kFLOAT16 = 0, //!< FP16 field type. + kFLOAT32 = 1, //!< FP32 field type. + kFLOAT64 = 2, //!< FP64 field type. + kINT8 = 3, //!< INT8 field type. + kINT16 = 4, //!< INT16 field type. + kINT32 = 5, //!< INT32 field type. + kCHAR = 6, //!< char field type. + kDIMS = 7, //!< nvinfer1::Dims field type. kUNKNOWN = 8 }; @@ -812,17 +792,16 @@ enum class PluginFieldType : int32_t //! This information can be parsed to decode necessary plugin metadata //! //! -class PluginField -{ +class PluginField { public: //! //! \brief Plugin field attribute name //! - const char* name{nullptr}; + const char *name{nullptr}; //! //! \brief Plugin field attribute data //! - const void* data{nullptr}; + const void *data{nullptr}; //! //! \brief Plugin field attribute type //! \see PluginFieldType @@ -833,19 +812,14 @@ class PluginField //! int32_t length{0}; - PluginField(const char* name_ = nullptr, const void* data_ = nullptr, const PluginFieldType type_ = PluginFieldType::kUNKNOWN, int32_t length_ = 0) - : name(name_) - , data(data_) - , type(type_) - , length(length_) - { + PluginField(const char *name_ = nullptr, const void *data_ = nullptr, const PluginFieldType type_ = PluginFieldType::kUNKNOWN, int32_t length_ = 0) + : name(name_), data(data_), type(type_), length(length_) { } }; -struct PluginFieldCollection -{ - int32_t nbFields; //!< Number of PluginField entries - const PluginField* fields; //!< Pointer to PluginField entries +struct PluginFieldCollection { + int32_t nbFields; //!< Number of PluginField entries + const PluginField *fields; //!< Pointer to PluginField entries }; //! @@ -856,42 +830,40 @@ struct PluginFieldCollection //! \see IPlugin and IPluginFactory //! -class IPluginCreator -{ +class IPluginCreator { public: //! //! \brief Return the version of the API the plugin creator was compiled with. //! - virtual int32_t getTensorRTVersion() const TRTNOEXCEPT - { + virtual int32_t getTensorRTVersion() const TRTNOEXCEPT { return NV_TENSORRT_VERSION; } //! //! \brief Return the plugin name. //! - virtual const char* getPluginName() const TRTNOEXCEPT = 0; + virtual const char *getPluginName() const TRTNOEXCEPT = 0; //! //! \brief Return the plugin version. //! - virtual const char* getPluginVersion() const TRTNOEXCEPT = 0; + virtual const char *getPluginVersion() const TRTNOEXCEPT = 0; //! //! \brief Return a list of fields that needs to be passed to createPlugin. //! \see PluginFieldCollection //! - virtual const PluginFieldCollection* getFieldNames() TRTNOEXCEPT = 0; + virtual const PluginFieldCollection *getFieldNames() TRTNOEXCEPT = 0; //! //! \brief Return a plugin object. Return nullptr in case of error. //! - virtual IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) TRTNOEXCEPT = 0; + virtual IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) TRTNOEXCEPT = 0; //! //! \brief Called during deserialization of plugin layer. Return a plugin object. //! - virtual IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRTNOEXCEPT = 0; + virtual IPluginV2 *deserializePlugin(const char *name, const void *serialData, size_t serialLength) TRTNOEXCEPT = 0; //! //! \brief Set the namespace of the plugin creator based on the plugin @@ -899,14 +871,15 @@ class IPluginCreator //! //! \see IPluginRegistry::registerCreator() //! - virtual void setPluginNamespace(const char* pluginNamespace) TRTNOEXCEPT = 0; + virtual void setPluginNamespace(const char *pluginNamespace) TRTNOEXCEPT = 0; //! //! \brief Return the namespace of the plugin creator object. //! - virtual const char* getPluginNamespace() const TRTNOEXCEPT = 0; + virtual const char *getPluginNamespace() const TRTNOEXCEPT = 0; - virtual ~IPluginCreator() {} + virtual ~IPluginCreator() { + } }; //! @@ -924,29 +897,29 @@ class IPluginCreator //! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. //! -class IPluginRegistry -{ +class IPluginRegistry { public: //! //! \brief Register a plugin creator. Returns false if one with same type //! is already registered. //! - virtual bool registerCreator(IPluginCreator& creator, const char* pluginNamespace) noexcept = 0; + virtual bool registerCreator(IPluginCreator &creator, const char *pluginNamespace) noexcept = 0; //! //! \brief Return all the registered plugin creators and the number of //! registered plugin creators. Returns nullptr if none found. //! - virtual IPluginCreator* const* getPluginCreatorList(int32_t* numCreators) const noexcept = 0; + virtual IPluginCreator *const *getPluginCreatorList(int32_t *numCreators) const noexcept = 0; //! //! \brief Return plugin creator based on plugin type, version and //! namespace associated with plugin during network creation. //! - virtual IPluginCreator* getPluginCreator(const char* pluginType, const char* pluginVersion, const char* pluginNamespace = "") noexcept = 0; + virtual IPluginCreator *getPluginCreator(const char *pluginType, const char *pluginVersion, const char *pluginNamespace = "") noexcept = 0; protected: - virtual ~IPluginRegistry() noexcept {} + virtual ~IPluginRegistry() noexcept { + } public: //! @@ -961,7 +934,7 @@ class IPluginRegistry // //! \see getErrorRecorder //! - virtual void setErrorRecorder(IErrorRecorder* recorder) noexcept = 0; + virtual void setErrorRecorder(IErrorRecorder *recorder) noexcept = 0; //! //! \brief set the ErrorRecorder assigned to this interface. @@ -974,23 +947,21 @@ class IPluginRegistry //! //! \see setErrorRecorder //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; + virtual IErrorRecorder *getErrorRecorder() const noexcept = 0; }; //! //! \enum TensorLocation //! \brief The location for tensor data storage, device or host. //! -enum class TensorLocation : int32_t -{ - kDEVICE = 0, //!< Data stored on device. - kHOST = 1, //!< Data stored on host. +enum class TensorLocation : int32_t { + kDEVICE = 0, //!< Data stored on device. + kHOST = 1, //!< Data stored on host. }; //! Maximum number of elements in TensorLocation enum. \see TensorLocation -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 2; } @@ -999,8 +970,7 @@ constexpr inline int32_t EnumMax() //! //! \brief Application-implemented class for controlling allocation on the GPU. //! -class IGpuAllocator -{ +class IGpuAllocator { public: //! //! A callback implemented by the application to handle acquisition of GPU memory. @@ -1016,7 +986,7 @@ class IGpuAllocator //! //! If an allocation request cannot be satisfied, nullptr should be returned. //! - virtual void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) TRTNOEXCEPT = 0; + virtual void *allocate(uint64_t size, uint64_t alignment, uint32_t flags) TRTNOEXCEPT = 0; //! //! A callback implemented by the application to handle release of GPU memory. @@ -1025,13 +995,14 @@ class IGpuAllocator //! //! \param memory The acquired memory. //! - virtual void free(void* memory) TRTNOEXCEPT = 0; + virtual void free(void *memory) TRTNOEXCEPT = 0; //! //! Destructor declared virtual as general good practice for a class with virtual methods. //! TensorRT never calls the destructor for an IGpuAllocator defined by the application. //! - virtual ~IGpuAllocator() {} + virtual ~IGpuAllocator() { + } }; //! @@ -1042,21 +1013,19 @@ class IGpuAllocator //! Note that although a logger is passed on creation to each instance of a IBuilder or safe::IRuntime interface, the logger is internally considered a singleton, and thus //! multiple instances of safe::IRuntime and/or IBuilder must all use the same logger. //! -class ILogger -{ +class ILogger { public: //! //! \enum Severity //! //! The severity corresponding to a log message. //! - enum class Severity : int32_t - { - kINTERNAL_ERROR = 0, //!< Internal error has occurred. Execution is unrecoverable. - kERROR = 1, //!< Application error has occurred. - kWARNING = 2, //!< Application error has been discovered. TensorRT has recovered or fallen back to a default. - kINFO = 3, //!< Informational messages with instructional information. - kVERBOSE = 4, //!< Verbose messages with debugging information. + enum class Severity : int32_t { + kINTERNAL_ERROR = 0, //!< Internal error has occurred. Execution is unrecoverable. + kERROR = 1, //!< Application error has occurred. + kWARNING = 2, //!< Application error has been discovered. TensorRT has recovered or fallen back to a default. + kINFO = 3, //!< Informational messages with instructional information. + kVERBOSE = 4, //!< Verbose messages with debugging information. }; //! @@ -1065,15 +1034,15 @@ class ILogger //! \param severity The severity of the message. //! \param msg The log message, null terminated. //! - virtual void log(Severity severity, const char* msg) TRTNOEXCEPT = 0; + virtual void log(Severity severity, const char *msg) TRTNOEXCEPT = 0; - virtual ~ILogger() {} + virtual ~ILogger() { + } }; //! Maximum number of elements in ILogger::Severity enum. \see ILogger::Severity -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 5; } @@ -1082,8 +1051,7 @@ constexpr inline int32_t EnumMax() //! //! \brief Error codes that can be returned by TensorRT during execution. //! -enum class ErrorCode : int32_t -{ +enum class ErrorCode : int32_t { //! //! Execution completed successfully. //! @@ -1172,9 +1140,8 @@ enum class ErrorCode : int32_t }; //! Maximum number of elements in ErrorCode enum. \see ErrorCode -template <> -constexpr inline int32_t EnumMax() -{ +template<> +constexpr inline int32_t EnumMax() { return 11; } @@ -1199,13 +1166,12 @@ constexpr inline int32_t EnumMax() //! pushed to the interface implementation and TensorRT does not hold any synchronization primitives when accessing //! the interface functions. //! -class IErrorRecorder -{ +class IErrorRecorder { public: //! //! A typedef of a c-style string for reporting error descriptions. //! - using ErrorDesc = const char*; + using ErrorDesc = const char *; //! //! A typedef of a 32bit integer for reference counting. @@ -1326,20 +1292,20 @@ class IErrorRecorder //! virtual RefCount decRefCount() noexcept = 0; -}; // class IErrorRecorder +}; // class IErrorRecorder -} // namespace nvinfer1 +} // namespace nvinfer1 //! //! Internal C entry point for creating safe::IRuntime. //! @private //! -extern "C" TENSORRTAPI void* createSafeInferRuntime_INTERNAL(void* logger, int32_t version); +extern "C" TENSORRTAPI void *createSafeInferRuntime_INTERNAL(void *logger, int32_t version); //! //! \brief Return the logger object. //! -extern "C" TENSORRTAPI nvinfer1::ILogger* getLogger(); +extern "C" TENSORRTAPI nvinfer1::ILogger *getLogger(); //! //! \brief Return the library version number. @@ -1351,10 +1317,9 @@ extern "C" TENSORRTAPI int32_t getInferLibVersion(); //! //! \brief Return the plugin registry //! -extern "C" TENSORRTAPI nvinfer1::IPluginRegistry* getPluginRegistry(); +extern "C" TENSORRTAPI nvinfer1::IPluginRegistry *getPluginRegistry(); -namespace nvinfer1 -{ +namespace nvinfer1 { //! //! \brief Register the plugin creator to the registry @@ -1362,18 +1327,21 @@ namespace nvinfer1 //! loaded. This static object will register all creators available in the //! library to the registry. //! -template -class PluginRegistrar -{ +template +class PluginRegistrar { public: - PluginRegistrar() { getPluginRegistry()->registerCreator(instance, ""); } + PluginRegistrar() { + getPluginRegistry()->registerCreator(instance, ""); + } + private: T instance{}; }; -#define REGISTER_TENSORRT_PLUGIN(name) \ - static nvinfer1::PluginRegistrar pluginRegistrar##name {} +#define REGISTER_TENSORRT_PLUGIN(name) \ + static nvinfer1::PluginRegistrar pluginRegistrar##name { \ + } -} // namespace nvinfer1 +} // namespace nvinfer1 -#endif // NV_INFER_RUNTIME_COMMON_H +#endif // NV_INFER_RUNTIME_COMMON_H diff --git a/src/bb/dnn/NvInferVersion.h b/src/bb/dnn/NvInferVersion.h index caabc229..6ae38760 100644 --- a/src/bb/dnn/NvInferVersion.h +++ b/src/bb/dnn/NvInferVersion.h @@ -14,22 +14,22 @@ * limitations under the License. */ - //! - //! \file NvInferVersion.h - //! - //! Defines the TensorRT version - //! +//! +//! \file NvInferVersion.h +//! +//! Defines the TensorRT version +//! #ifndef NV_INFER_VERSION_H #define NV_INFER_VERSION_H -#define NV_TENSORRT_MAJOR 7 //!< TensorRT major version. -#define NV_TENSORRT_MINOR 2 //!< TensorRT minor version. -#define NV_TENSORRT_PATCH 1 //!< TensorRT patch version. -#define NV_TENSORRT_BUILD 6 //!< TensorRT build number. +#define NV_TENSORRT_MAJOR 7 //!< TensorRT major version. +#define NV_TENSORRT_MINOR 2 //!< TensorRT minor version. +#define NV_TENSORRT_PATCH 1 //!< TensorRT patch version. +#define NV_TENSORRT_BUILD 6 //!< TensorRT build number. -#define NV_TENSORRT_SONAME_MAJOR 7 //!< Shared object library major version number. -#define NV_TENSORRT_SONAME_MINOR 2 //!< Shared object library minor version number. -#define NV_TENSORRT_SONAME_PATCH 1 //!< Shared object library patch version number. +#define NV_TENSORRT_SONAME_MAJOR 7 //!< Shared object library major version number. +#define NV_TENSORRT_SONAME_MINOR 2 //!< Shared object library minor version number. +#define NV_TENSORRT_SONAME_PATCH 1 //!< Shared object library patch version number. -#endif // NV_INFER_VERSION_H +#endif // NV_INFER_VERSION_H diff --git a/src/bb/dnn/edgetpu_c.h b/src/bb/dnn/edgetpu_c.h index d3190808..9df59132 100644 --- a/src/bb/dnn/edgetpu_c.h +++ b/src/bb/dnn/edgetpu_c.h @@ -9,33 +9,33 @@ extern "C" { #endif enum edgetpu_device_type { - EDGETPU_APEX_PCI = 0, - EDGETPU_APEX_USB = 1, + EDGETPU_APEX_PCI = 0, + EDGETPU_APEX_USB = 1, }; struct edgetpu_device { - enum edgetpu_device_type type; - const char* path; + enum edgetpu_device_type type; + const char *path; }; struct edgetpu_option { - const char* name; - const char* value; + const char *name; + const char *value; }; -using edgetpu_list_devices_t = struct edgetpu_device* (*)(size_t* num_devices); -using edgetpu_free_devices_t = void (*)(struct edgetpu_device* dev); -using edgetpu_create_delegate_t = TfLiteDelegate* (*)(enum edgetpu_device_type type, const char* name, const struct edgetpu_option* options, size_t num_options); -using edgetpu_free_delegate_t = void (*)(TfLiteDelegate* delegate); +using edgetpu_list_devices_t = struct edgetpu_device *(*)(size_t *num_devices); +using edgetpu_free_devices_t = void (*)(struct edgetpu_device *dev); +using edgetpu_create_delegate_t = TfLiteDelegate *(*)(enum edgetpu_device_type type, const char *name, const struct edgetpu_option *options, size_t num_options); +using edgetpu_free_delegate_t = void (*)(TfLiteDelegate *delegate); using edgetpu_verbosity_t = void (*)(int verbosity); -using edgetpu_version_t = const char* (*)(); +using edgetpu_version_t = const char *(*)(); -edgetpu_list_devices_t edgetpu_list_devices; -edgetpu_free_devices_t edgetpu_free_devices; +edgetpu_list_devices_t edgetpu_list_devices; +edgetpu_free_devices_t edgetpu_free_devices; edgetpu_create_delegate_t edgetpu_create_delegate; -edgetpu_free_delegate_t edgetpu_free_delegate; -edgetpu_verbosity_t edgetpu_verbosity; -edgetpu_version_t edgetpu_version; +edgetpu_free_delegate_t edgetpu_free_delegate; +edgetpu_verbosity_t edgetpu_verbosity; +edgetpu_version_t edgetpu_version; bool edgetpu_init() { static ion::bb::dnn::DynamicModule dm("libedgetpu.so.1", true, true); @@ -43,12 +43,12 @@ bool edgetpu_init() { return false; } -#define RESOLVE_SYMBOL(SYM_NAME) \ - SYM_NAME = dm.get_symbol(#SYM_NAME); \ - if (SYM_NAME == nullptr) { \ - throw std::runtime_error( \ - #SYM_NAME " is unavailable on your edgetpu DSO"); \ - } +#define RESOLVE_SYMBOL(SYM_NAME) \ + SYM_NAME = dm.get_symbol(#SYM_NAME); \ + if (SYM_NAME == nullptr) { \ + throw std::runtime_error( \ + #SYM_NAME " is unavailable on your edgetpu DSO"); \ + } RESOLVE_SYMBOL(edgetpu_list_devices); RESOLVE_SYMBOL(edgetpu_free_devices); diff --git a/src/bb/dnn/httplib.h b/src/bb/dnn/httplib.h index d1d8486b..363afe7a 100644 --- a/src/bb/dnn/httplib.h +++ b/src/bb/dnn/httplib.h @@ -81,10 +81,8 @@ #endif #ifndef CPPHTTPLIB_THREAD_POOL_COUNT -#define CPPHTTPLIB_THREAD_POOL_COUNT \ - ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ - ? std::thread::hardware_concurrency() - 1 \ - : 0)) +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() - 1 : 0)) #endif /* @@ -94,11 +92,11 @@ #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS -#endif //_CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS #ifndef _CRT_NONSTDC_NO_DEPRECATE #define _CRT_NONSTDC_NO_DEPRECATE -#endif //_CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE #if defined(_MSC_VER) #ifdef _WIN64 @@ -110,19 +108,19 @@ using ssize_t = int; #if _MSC_VER < 1900 #define snprintf _snprintf_s #endif -#endif // _MSC_VER +#endif // _MSC_VER #ifndef S_ISREG #define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) -#endif // S_ISREG +#endif // S_ISREG #ifndef S_ISDIR #define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) -#endif // S_ISDIR +#endif // S_ISDIR #ifndef NOMINMAX #define NOMINMAX -#endif // NOMINMAX +#endif // NOMINMAX #include #include @@ -142,14 +140,14 @@ using ssize_t = int; #ifndef strcasecmp #define strcasecmp _stricmp -#endif // strcasecmp +#endif // strcasecmp using socket_t = SOCKET; #ifdef CPPHTTPLIB_USE_POLL #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif -#else // not _WIN32 +#else // not _WIN32 #include #include @@ -171,7 +169,7 @@ using socket_t = SOCKET; using socket_t = int; #define INVALID_SOCKET (-1) -#endif //_WIN32 +#endif //_WIN32 #include #include @@ -216,7 +214,7 @@ using socket_t = int; #if OPENSSL_VERSION_NUMBER < 0x10100000L #include inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { - return M_ASN1_STRING_data(asn1); + return M_ASN1_STRING_data(asn1); } #endif #endif @@ -238,14 +236,14 @@ namespace httplib { namespace detail { struct ci { - bool operator()(const std::string &s1, const std::string &s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), s2.begin(), s2.end(), - [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); - } + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } }; -} // namespace detail +} // namespace detail using Headers = std::multimap; @@ -258,44 +256,48 @@ struct Response; using ResponseHandler = std::function; struct MultipartFormData { - std::string name; - std::string content; - std::string filename; - std::string content_type; + std::string name; + std::string content; + std::string filename; + std::string content_type; }; using MultipartFormDataItems = std::vector; using MultipartFormDataMap = std::multimap; class DataSink { public: - DataSink() : os(&sb_), sb_(*this) {} + DataSink() + : os(&sb_), sb_(*this) { + } - DataSink(const DataSink &) = delete; - DataSink &operator=(const DataSink &) = delete; - DataSink(DataSink &&) = delete; - DataSink &operator=(DataSink &&) = delete; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; - std::function write; - std::function done; - std::function is_writable; - std::ostream os; + std::function write; + std::function done; + std::function is_writable; + std::ostream os; private: - class data_sink_streambuf : public std::streambuf { - public: - explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + class data_sink_streambuf : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) + : sink_(sink) { + } - protected: - std::streamsize xsputn(const char *s, std::streamsize n) { - sink_.write(s, static_cast(n)); - return n; - } + protected: + std::streamsize xsputn(const char *s, std::streamsize n) { + sink_.write(s, static_cast(n)); + return n; + } - private: - DataSink &sink_; - }; + private: + DataSink &sink_; + }; - data_sink_streambuf sb_; + data_sink_streambuf sb_; }; using ContentProvider = @@ -312,223 +314,231 @@ using MultipartContentHeader = class ContentReader { public: - using Reader = std::function; - using MultipartReader = std::function; + using Reader = std::function; + using MultipartReader = std::function; - ContentReader(Reader reader, MultipartReader multipart_reader) - : reader_(reader), multipart_reader_(multipart_reader) {} + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(reader), multipart_reader_(multipart_reader) { + } - bool operator()(MultipartContentHeader header, - ContentReceiver receiver) const { - return multipart_reader_(header, receiver); - } + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return multipart_reader_(header, receiver); + } - bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + bool operator()(ContentReceiver receiver) const { + return reader_(receiver); + } - Reader reader_; - MultipartReader multipart_reader_; + Reader reader_; + MultipartReader multipart_reader_; }; using Range = std::pair; using Ranges = std::vector; struct Request { - std::string method; - std::string path; - Headers headers; - std::string body; - - std::string remote_addr; - int remote_port = -1; - - // for server - std::string version; - std::string target; - Params params; - MultipartFormDataMap files; - Ranges ranges; - Match matches; - - // for client - size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; - ResponseHandler response_handler; - ContentReceiver content_receiver; - size_t content_length = 0; - ContentProvider content_provider; - Progress progress; + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + size_t content_length = 0; + ContentProvider content_provider; + Progress progress; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl; + const SSL *ssl; #endif - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - template - T get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - bool has_param(const char *key) const; - std::string get_param_value(const char *key, size_t id = 0) const; - size_t get_param_value_count(const char *key) const; + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; - bool is_multipart_form_data() const; + bool is_multipart_form_data() const; - bool has_file(const char *key) const; - MultipartFormData get_file_value(const char *key) const; + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; - // private members... - size_t authorization_count_ = 0; + // private members... + size_t authorization_count_ = 0; }; struct Response { - std::string version; - int status = -1; - std::string reason; - Headers headers; - std::string body; - - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - template - T get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); - - void set_redirect(const char *url, int status = 302); - void set_redirect(const std::string &url, int status = 302); - void set_content(const char *s, size_t n, const char *content_type); - void set_content(std::string s, const char *content_type); - - void set_content_provider( - size_t length, const char *content_type, ContentProvider provider, - const std::function &resource_releaser = nullptr); - - void set_content_provider( - const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); - - void set_chunked_content_provider( - const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); - - Response() = default; - Response(const Response &) = default; - Response &operator=(const Response &) = default; - Response(Response &&) = default; - Response &operator=(Response &&) = default; - ~Response() { - if (content_provider_resource_releaser_) { - content_provider_resource_releaser_(); - } - } - - // private members... - size_t content_length_ = 0; - ContentProvider content_provider_; - std::function content_provider_resource_releaser_; - bool is_chunked_content_provider = false; + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url, int status = 302); + void set_redirect(const std::string &url, int status = 302); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(std::string s, const char *content_type); + + void set_content_provider( + size_t length, const char *content_type, ContentProvider provider, + const std::function &resource_releaser = nullptr); + + void set_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + const std::function &resource_releaser = nullptr); + + void set_chunked_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + const std::function &resource_releaser = nullptr); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + std::function content_provider_resource_releaser_; + bool is_chunked_content_provider = false; }; class Stream { public: - virtual ~Stream() = default; + virtual ~Stream() = default; - virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; - virtual ssize_t read(char *ptr, size_t size) = 0; - virtual ssize_t write(const char *ptr, size_t size) = 0; - virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; - template - ssize_t write_format(const char *fmt, const Args &... args); - ssize_t write(const char *ptr); - ssize_t write(const std::string &s); + template + ssize_t write_format(const char *fmt, const Args &...args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); }; class TaskQueue { public: - TaskQueue() = default; - virtual ~TaskQueue() = default; + TaskQueue() = default; + virtual ~TaskQueue() = default; - virtual void enqueue(std::function fn) = 0; - virtual void shutdown() = 0; + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; - virtual void on_idle(){}; + virtual void on_idle(){}; }; class ThreadPool : public TaskQueue { public: - explicit ThreadPool(size_t n) : shutdown_(false) { - while (n) { - threads_.emplace_back(worker(*this)); - n--; + explicit ThreadPool(size_t n) + : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } } - } - - ThreadPool(const ThreadPool &) = delete; - ~ThreadPool() override = default; - void enqueue(std::function fn) override { - std::unique_lock lock(mutex_); - jobs_.push_back(fn); - cond_.notify_one(); - } + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; - void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); } - cond_.notify_all(); + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); - // Join... - for (auto &t : threads_) { - t.join(); + // Join... + for (auto &t : threads_) { + t.join(); + } } - } private: - struct worker { - explicit worker(ThreadPool &pool) : pool_(pool) {} + struct worker { + explicit worker(ThreadPool &pool) + : pool_(pool) { + } - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + if (pool_.shutdown_ && pool_.jobs_.empty()) { + break; + } - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } - assert(true == static_cast(fn)); - fn(); - } - } + assert(true == static_cast(fn)); + fn(); + } + } - ThreadPool &pool_; - }; - friend struct worker; + ThreadPool &pool_; + }; + friend struct worker; - std::vector threads_; - std::list> jobs_; + std::vector threads_; + std::list> jobs_; - bool shutdown_; + bool shutdown_; - std::condition_variable cond_; - std::mutex mutex_; + std::condition_variable cond_; + std::mutex mutex_; }; using Logger = std::function; @@ -536,670 +546,687 @@ using Logger = std::function; using SocketOptions = std::function; inline void default_socket_options(socket_t sock) { - int yes = 1; + int yes = 1; #ifdef _WIN32 - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); - setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast(&yes), sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&yes), sizeof(yes)); #else #ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), - sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); #else - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); #endif #endif } class Server { public: - using Handler = std::function; - using HandlerWithContentReader = std::function; - using Expect100ContinueHandler = - std::function; + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Expect100ContinueHandler = + std::function; - Server(); + Server(); - virtual ~Server(); + virtual ~Server(); - virtual bool is_valid() const; + virtual bool is_valid() const; - Server &Get(const char *pattern, Handler handler); - Server &Post(const char *pattern, Handler handler); - Server &Post(const char *pattern, HandlerWithContentReader handler); - Server &Put(const char *pattern, Handler handler); - Server &Put(const char *pattern, HandlerWithContentReader handler); - Server &Patch(const char *pattern, Handler handler); - Server &Patch(const char *pattern, HandlerWithContentReader handler); - Server &Delete(const char *pattern, Handler handler); - Server &Delete(const char *pattern, HandlerWithContentReader handler); - Server &Options(const char *pattern, Handler handler); + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Delete(const char *pattern, HandlerWithContentReader handler); + Server &Options(const char *pattern, Handler handler); - bool set_base_dir(const char *dir, const char *mount_point = nullptr); - bool set_mount_point(const char *mount_point, const char *dir); - bool remove_mount_point(const char *mount_point); - void set_file_extension_and_mimetype_mapping(const char *ext, - const char *mime); - void set_file_request_handler(Handler handler); + bool set_base_dir(const char *dir, const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); - void set_error_handler(Handler handler); - void set_expect_100_continue_handler(Expect100ContinueHandler handler); - void set_logger(Logger logger); + void set_error_handler(Handler handler); + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + void set_logger(Logger logger); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_keep_alive_max_count(size_t count); - void set_keep_alive_timeout(time_t sec); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); - void set_idle_interval(time_t sec, time_t usec = 0); + void set_keep_alive_max_count(size_t count); + void set_keep_alive_timeout(time_t sec); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); + void set_idle_interval(time_t sec, time_t usec = 0); - void set_payload_max_length(size_t length); + void set_payload_max_length(size_t length); - bool bind_to_port(const char *host, int port, int socket_flags = 0); - int bind_to_any_port(const char *host, int socket_flags = 0); - bool listen_after_bind(); + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - bool listen(const char *host, int port, int socket_flags = 0); + bool listen(const char *host, int port, int socket_flags = 0); - bool is_running() const; - void stop(); + bool is_running() const; + void stop(); - std::function new_task_queue; + std::function new_task_queue; protected: - bool process_request(Stream &strm, bool close_connection, - bool &connection_closed, - const std::function &setup_request); - - std::atomic svr_sock_; - size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; - time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; - time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; - size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + bool process_request(Stream &strm, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; private: - using Handlers = std::vector>; - using HandlersForContentReader = - std::vector>; - - socket_t create_server_socket(const char *host, int port, int socket_flags, - SocketOptions socket_options) const; - int bind_internal(const char *host, int port, int socket_flags); - bool listen_internal(); - - bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(Request &req, Response &res, bool head = false); - bool dispatch_request(Request &req, Response &res, const Handlers &handlers); - bool - dispatch_request_for_content_reader(Request &req, Response &res, - ContentReader content_reader, - const HandlersForContentReader &handlers); - - bool parse_request_line(const char *s, Request &req); - bool write_response(Stream &strm, bool close_connection, const Request &req, - Response &res); - bool write_content_with_provider(Stream &strm, const Request &req, - Response &res, const std::string &boundary, - const std::string &content_type); - bool read_content(Stream &strm, Request &req, Response &res); - bool - read_content_with_content_receiver(Stream &strm, Request &req, Response &res, - ContentReceiver receiver, - MultipartContentHeader multipart_header, - ContentReceiver multipart_receiver); - bool read_content_core(Stream &strm, Request &req, Response &res, - ContentReceiver receiver, - MultipartContentHeader mulitpart_header, - ContentReceiver multipart_receiver); - - virtual bool process_and_close_socket(socket_t sock); - - std::atomic is_running_; - std::vector> base_dirs_; - std::map file_extension_and_mimetype_map_; - Handler file_request_handler_; - Handlers get_handlers_; - Handlers post_handlers_; - HandlersForContentReader post_handlers_for_content_reader_; - Handlers put_handlers_; - HandlersForContentReader put_handlers_for_content_reader_; - Handlers patch_handlers_; - HandlersForContentReader patch_handlers_for_content_reader_; - Handlers delete_handlers_; - HandlersForContentReader delete_handlers_for_content_reader_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; - Expect100ContinueHandler expect_100_continue_handler_; - - bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; - SocketOptions socket_options_ = default_socket_options; + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; + + socket_t create_server_socket(const char *host, int port, int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, const Handlers &handlers); + bool + dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + const HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool close_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; + + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = default_socket_options; }; enum Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification }; class Result { public: - Result(const std::shared_ptr &res, Error err) - : res_(res), err_(err) {} - operator bool() const { return res_ != nullptr; } - bool operator==(std::nullptr_t) const { return res_ == nullptr; } - bool operator!=(std::nullptr_t) const { return res_ != nullptr; } - const Response &value() const { return *res_; } - const Response &operator*() const { return *res_; } - const Response *operator->() const { return res_.get(); } - Error error() const { return err_; } + Result(const std::shared_ptr &res, Error err) + : res_(res), err_(err) { + } + operator bool() const { + return res_ != nullptr; + } + bool operator==(std::nullptr_t) const { + return res_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return res_ != nullptr; + } + const Response &value() const { + return *res_; + } + const Response &operator*() const { + return *res_; + } + const Response *operator->() const { + return res_.get(); + } + Error error() const { + return err_; + } private: - std::shared_ptr res_; - Error err_; + std::shared_ptr res_; + Error err_; }; class ClientImpl { public: - explicit ClientImpl(const std::string &host); - - explicit ClientImpl(const std::string &host, int port); - - explicit ClientImpl(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); - - virtual ~ClientImpl(); - - virtual bool is_valid() const; - - Result Get(const char *path); - Result Get(const char *path, const Headers &headers); - Result Get(const char *path, Progress progress); - Result Get(const char *path, const Headers &headers, Progress progress); - Result Get(const char *path, ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); - Result Get(const char *path, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, ContentReceiver content_receiver, - Progress progress); - - Result Head(const char *path); - Result Head(const char *path, const Headers &headers); - - Result Post(const char *path); - Result Post(const char *path, const std::string &body, - const char *content_type); - Result Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Post(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Params ¶ms); - Result Post(const char *path, const Headers &headers, const Params ¶ms); - Result Post(const char *path, const MultipartFormDataItems &items); - Result Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - - Result Put(const char *path); - Result Put(const char *path, const std::string &body, - const char *content_type); - Result Put(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Put(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Params ¶ms); - Result Put(const char *path, const Headers &headers, const Params ¶ms); - - Result Patch(const char *path, const std::string &body, + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const char *path); + Result Get(const char *path, const Headers &headers); + Result Get(const char *path, Progress progress); + Result Get(const char *path, const Headers &headers, Progress progress); + Result Get(const char *path, ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const char *path, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Head(const char *path); + Result Head(const char *path, const Headers &headers); + + Result Post(const char *path); + Result Post(const char *path, const std::string &body, + const char *content_type); + Result Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Post(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Params ¶ms); + Result Post(const char *path, const Headers &headers, const Params ¶ms); + Result Post(const char *path, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + Result Put(const char *path); + Result Put(const char *path, const std::string &body, + const char *content_type); + Result Put(const char *path, const Headers &headers, const std::string &body, const char *content_type); - Result Patch(const char *path, const Headers &headers, - const std::string &body, const char *content_type); - Result Patch(const char *path, size_t content_length, + Result Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type); - Result Patch(const char *path, const Headers &headers, size_t content_length, + Result Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type); + Result Put(const char *path, const Params ¶ms); + Result Put(const char *path, const Headers &headers, const Params ¶ms); - Result Delete(const char *path); - Result Delete(const char *path, const std::string &body, - const char *content_type); - Result Delete(const char *path, const Headers &headers); - Result Delete(const char *path, const Headers &headers, - const std::string &body, const char *content_type); + Result Patch(const char *path, const std::string &body, + const char *content_type); + Result Patch(const char *path, const Headers &headers, + const std::string &body, const char *content_type); + Result Patch(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + + Result Delete(const char *path); + Result Delete(const char *path, const std::string &body, + const char *content_type); + Result Delete(const char *path, const Headers &headers); + Result Delete(const char *path, const Headers &headers, + const std::string &body, const char *content_type); - Result Options(const char *path); - Result Options(const char *path, const Headers &headers); + Result Options(const char *path); + Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res); + bool send(const Request &req, Response &res); - size_t is_socket_open() const; + size_t is_socket_open() const; - void stop(); + void stop(); - void set_default_headers(Headers headers); + void set_default_headers(Headers headers); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_connection_timeout(time_t sec, time_t usec = 0); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); - void set_basic_auth(const char *username, const char *password); - void set_bearer_token_auth(const char *token); + void set_basic_auth(const char *username, const char *password); + void set_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const char *username, const char *password); + void set_digest_auth(const char *username, const char *password); #endif - void set_keep_alive(bool on); - void set_follow_location(bool on); + void set_keep_alive(bool on); + void set_follow_location(bool on); - void set_compress(bool on); + void set_compress(bool on); - void set_decompress(bool on); + void set_decompress(bool on); - void set_interface(const char *intf); + void set_interface(const char *intf); - void set_proxy(const char *host, int port); - void set_proxy_basic_auth(const char *username, const char *password); - void set_proxy_bearer_token_auth(const char *token); + void set_proxy(const char *host, int port); + void set_proxy_basic_auth(const char *username, const char *password); + void set_proxy_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const char *username, const char *password); + void set_proxy_digest_auth(const char *username, const char *password); #endif - void set_logger(Logger logger); + void set_logger(Logger logger); protected: - struct Socket { - socket_t sock = INVALID_SOCKET; + struct Socket { + socket_t sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; + SSL *ssl = nullptr; #endif - bool is_open() const { return sock != INVALID_SOCKET; } - }; + bool is_open() const { + return sock != INVALID_SOCKET; + } + }; - virtual bool create_and_connect_socket(Socket &socket); - virtual void close_socket(Socket &socket, bool process_socket_ret); + virtual bool create_and_connect_socket(Socket &socket); + virtual void close_socket(Socket &socket, bool process_socket_ret); - bool process_request(Stream &strm, const Request &req, Response &res, - bool close_connection); + bool process_request(Stream &strm, const Request &req, Response &res, + bool close_connection); - Error get_last_error() const; + Error get_last_error() const; - // Error state - mutable Error error_ = Error::Success; + // Error state + mutable Error error_ = Error::Success; - // Socket endoint information - const std::string host_; - const int port_; - const std::string host_and_port_; + // Socket endoint information + const std::string host_; + const int port_; + const std::string host_and_port_; - // Current open socket - Socket socket_; - mutable std::mutex socket_mutex_; - std::recursive_mutex request_mutex_; + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; - // Default headers - Headers default_headers_; + // Default headers + Headers default_headers_; - // Settings - std::string client_cert_path_; - std::string client_key_path_; + // Settings + std::string client_cert_path_; + std::string client_key_path_; - time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; - time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - std::string basic_auth_username_; - std::string basic_auth_password_; - std::string bearer_token_auth_token_; + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string digest_auth_username_; - std::string digest_auth_password_; + std::string digest_auth_username_; + std::string digest_auth_password_; #endif - bool keep_alive_ = false; - bool follow_location_ = false; + bool keep_alive_ = false; + bool follow_location_ = false; - bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; - SocketOptions socket_options_ = nullptr; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = nullptr; - bool compress_ = false; - bool decompress_ = true; + bool compress_ = false; + bool decompress_ = true; - std::string interface_; + std::string interface_; - std::string proxy_host_; - int proxy_port_ = -1; + std::string proxy_host_; + int proxy_port_ = -1; - std::string proxy_basic_auth_username_; - std::string proxy_basic_auth_password_; - std::string proxy_bearer_token_auth_token_; + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string proxy_digest_auth_username_; - std::string proxy_digest_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; #endif - Logger logger_; - - void copy_settings(const ClientImpl &rhs) { - client_cert_path_ = rhs.client_cert_path_; - client_key_path_ = rhs.client_key_path_; - connection_timeout_sec_ = rhs.connection_timeout_sec_; - read_timeout_sec_ = rhs.read_timeout_sec_; - read_timeout_usec_ = rhs.read_timeout_usec_; - write_timeout_sec_ = rhs.write_timeout_sec_; - write_timeout_usec_ = rhs.write_timeout_usec_; - basic_auth_username_ = rhs.basic_auth_username_; - basic_auth_password_ = rhs.basic_auth_password_; - bearer_token_auth_token_ = rhs.bearer_token_auth_token_; + Logger logger_; + + void copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - digest_auth_username_ = rhs.digest_auth_username_; - digest_auth_password_ = rhs.digest_auth_password_; + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; #endif - keep_alive_ = rhs.keep_alive_; - follow_location_ = rhs.follow_location_; - tcp_nodelay_ = rhs.tcp_nodelay_; - socket_options_ = rhs.socket_options_; - compress_ = rhs.compress_; - decompress_ = rhs.decompress_; - interface_ = rhs.interface_; - proxy_host_ = rhs.proxy_host_; - proxy_port_ = rhs.proxy_port_; - proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; - proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; - proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + tcp_nodelay_ = rhs.tcp_nodelay_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; - proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; #endif - logger_ = rhs.logger_; - } + logger_ = rhs.logger_; + } private: - socket_t create_client_socket() const; - bool read_response_line(Stream &strm, Response &res); - bool write_request(Stream &strm, const Request &req, bool close_connection); - bool redirect(const Request &req, Response &res); - bool handle_request(Stream &strm, const Request &req, Response &res, - bool close_connection); - void stop_core(); - std::shared_ptr send_with_content_provider( - const char *method, const char *path, const Headers &headers, - const std::string &body, size_t content_length, - ContentProvider content_provider, const char *content_type); - - virtual bool process_socket(Socket &socket, - std::function callback); - virtual bool is_ssl() const; + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool close_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool close_connection); + void stop_core(); + std::shared_ptr send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_socket(Socket &socket, + std::function callback); + virtual bool is_ssl() const; }; class Client { public: - // Universal interface - explicit Client(const char *scheme_host_port); - - explicit Client(const char *scheme_host_port, - const std::string &client_cert_path, - const std::string &client_key_path); - - // HTTP only interface - explicit Client(const std::string &host, int port); - - explicit Client(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); - - ~Client(); - - bool is_valid() const; - - Result Get(const char *path); - Result Get(const char *path, const Headers &headers); - Result Get(const char *path, Progress progress); - Result Get(const char *path, const Headers &headers, Progress progress); - Result Get(const char *path, ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); - Result Get(const char *path, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver, Progress progress); - - Result Head(const char *path); - Result Head(const char *path, const Headers &headers); - - Result Post(const char *path); - Result Post(const char *path, const std::string &body, - const char *content_type); - Result Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Post(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Params ¶ms); - Result Post(const char *path, const Headers &headers, const Params ¶ms); - Result Post(const char *path, const MultipartFormDataItems &items); - Result Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - Result Put(const char *path); - Result Put(const char *path, const std::string &body, - const char *content_type); - Result Put(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Put(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Params ¶ms); - Result Put(const char *path, const Headers &headers, const Params ¶ms); - Result Patch(const char *path, const std::string &body, + // Universal interface + explicit Client(const char *scheme_host_port); + + explicit Client(const char *scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + ~Client(); + + bool is_valid() const; + + Result Get(const char *path); + Result Get(const char *path, const Headers &headers); + Result Get(const char *path, Progress progress); + Result Get(const char *path, const Headers &headers, Progress progress); + Result Get(const char *path, ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const char *path, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Head(const char *path); + Result Head(const char *path, const Headers &headers); + + Result Post(const char *path); + Result Post(const char *path, const std::string &body, + const char *content_type); + Result Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Post(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Params ¶ms); + Result Post(const char *path, const Headers &headers, const Params ¶ms); + Result Post(const char *path, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const char *path); + Result Put(const char *path, const std::string &body, const char *content_type); - Result Patch(const char *path, const Headers &headers, - const std::string &body, const char *content_type); - Result Patch(const char *path, size_t content_length, + Result Put(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type); - Result Patch(const char *path, const Headers &headers, size_t content_length, + Result Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type); + Result Put(const char *path, const Params ¶ms); + Result Put(const char *path, const Headers &headers, const Params ¶ms); + Result Patch(const char *path, const std::string &body, + const char *content_type); + Result Patch(const char *path, const Headers &headers, + const std::string &body, const char *content_type); + Result Patch(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); - Result Delete(const char *path); - Result Delete(const char *path, const std::string &body, - const char *content_type); - Result Delete(const char *path, const Headers &headers); - Result Delete(const char *path, const Headers &headers, - const std::string &body, const char *content_type); + Result Delete(const char *path); + Result Delete(const char *path, const std::string &body, + const char *content_type); + Result Delete(const char *path, const Headers &headers); + Result Delete(const char *path, const Headers &headers, + const std::string &body, const char *content_type); - Result Options(const char *path); - Result Options(const char *path, const Headers &headers); + Result Options(const char *path); + Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res); + bool send(const Request &req, Response &res); - size_t is_socket_open() const; + size_t is_socket_open() const; - void stop(); + void stop(); - void set_default_headers(Headers headers); + void set_default_headers(Headers headers); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_connection_timeout(time_t sec, time_t usec = 0); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); - void set_basic_auth(const char *username, const char *password); - void set_bearer_token_auth(const char *token); + void set_basic_auth(const char *username, const char *password); + void set_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const char *username, const char *password); + void set_digest_auth(const char *username, const char *password); #endif - void set_keep_alive(bool on); - void set_follow_location(bool on); + void set_keep_alive(bool on); + void set_follow_location(bool on); - void set_compress(bool on); + void set_compress(bool on); - void set_decompress(bool on); + void set_decompress(bool on); - void set_interface(const char *intf); + void set_interface(const char *intf); - void set_proxy(const char *host, int port); - void set_proxy_basic_auth(const char *username, const char *password); - void set_proxy_bearer_token_auth(const char *token); + void set_proxy(const char *host, int port); + void set_proxy_basic_auth(const char *username, const char *password); + void set_proxy_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const char *username, const char *password); + void set_proxy_digest_auth(const char *username, const char *password); #endif - void set_logger(Logger logger); + void set_logger(Logger logger); - // SSL + // SSL #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - Client &set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path = nullptr); + Client &set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path = nullptr); - Client &set_ca_cert_store(X509_STORE *ca_cert_store); + Client &set_ca_cert_store(X509_STORE *ca_cert_store); - Client &enable_server_certificate_verification(bool enabled); + Client &enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX *ssl_context() const; + SSL_CTX *ssl_context() const; #endif private: - std::shared_ptr cli_; + std::shared_ptr cli_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool is_ssl_ = false; + bool is_ssl_ = false; #endif -}; // namespace httplib +}; // namespace httplib #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLServer : public Server { public: - SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path = nullptr, - const char *client_ca_cert_dir_path = nullptr); + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); - ~SSLServer() override; + ~SSLServer() override; - bool is_valid() const override; + bool is_valid() const override; private: - bool process_and_close_socket(socket_t sock) override; + bool process_and_close_socket(socket_t sock) override; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; }; class SSLClient : public ClientImpl { public: - explicit SSLClient(const std::string &host); + explicit SSLClient(const std::string &host); - explicit SSLClient(const std::string &host, int port); + explicit SSLClient(const std::string &host, int port); - explicit SSLClient(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key); + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key); - ~SSLClient() override; + ~SSLClient() override; - bool is_valid() const override; + bool is_valid() const override; - void set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path = nullptr); + void set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path = nullptr); - void set_ca_cert_store(X509_STORE *ca_cert_store); + void set_ca_cert_store(X509_STORE *ca_cert_store); - void enable_server_certificate_verification(bool enabled); + void enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX *ssl_context() const; + SSL_CTX *ssl_context() const; private: - bool create_and_connect_socket(Socket &socket) override; - void close_socket(Socket &socket, bool process_socket_ret) override; + bool create_and_connect_socket(Socket &socket) override; + void close_socket(Socket &socket, bool process_socket_ret) override; - bool process_socket(Socket &socket, - std::function callback) override; - bool is_ssl() const override; + bool process_socket(Socket &socket, + std::function callback) override; + bool is_ssl() const override; - bool connect_with_proxy(Socket &sock, Response &res, bool &success); - bool initialize_ssl(Socket &socket); + bool connect_with_proxy(Socket &sock, Response &res, bool &success); + bool initialize_ssl(Socket &socket); - bool load_certs(); + bool load_certs(); - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; - std::once_flag initialize_cert_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; - std::vector host_components_; + std::vector host_components_; - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; - X509_STORE *ca_cert_store_ = nullptr; - bool server_certificate_verification_ = true; - long verify_result_ = 0; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + X509_STORE *ca_cert_store_ = nullptr; + bool server_certificate_verification_ = true; + long verify_result_ = 0; - friend class ClientImpl; + friend class ClientImpl; }; #endif @@ -1212,724 +1239,779 @@ class SSLClient : public ClientImpl { namespace detail { inline bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } - return false; + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; } inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) { - if (i >= s.size()) { return false; } - - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { return false; } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { - return false; + if (i >= s.size()) { + return false; + } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { + return false; + } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } } - } - return true; + return true; } inline std::string from_i_to_hex(size_t n) { - const char *charset = "0123456789abcdef"; - std::string ret; - do { - ret = charset[n & 15] + ret; - n >>= 4; - } while (n > 0); - return ret; + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; } inline bool start_with(const std::string &a, const std::string &b) { - if (a.size() < b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { - if (std::tolower(a[i]) != std::tolower(b[i])) { return false; } - } - return true; + if (a.size() < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (std::tolower(a[i]) != std::tolower(b[i])) { + return false; + } + } + return true; } inline size_t to_utf8(int code, char *buff) { - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); - buff[1] = static_cast(0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); - buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[2] = static_cast(0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED return 0; - } else if (code < 0x10000) { - buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); - buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[2] = static_cast(0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); - buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); - buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[3] = static_cast(0x80 | (code & 0x3F)); - return 4; - } - - // NOTREACHED - return 0; } // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c inline std::string base64_encode(const std::string &in) { - static const auto lookup = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(in.size()); + std::string out; + out.reserve(in.size()); - int val = 0; - int valb = -6; + int val = 0; + int valb = -6; - for (auto c : in) { - val = (val << 8) + static_cast(c); - valb += 8; - while (valb >= 0) { - out.push_back(lookup[(val >> valb) & 0x3F]); - valb -= 6; + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } } - } - if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + if (valb > -6) { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } - while (out.size() % 4) { - out.push_back('='); - } + while (out.size() % 4) { + out.push_back('='); + } - return out; + return out; } inline bool is_file(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); } inline bool is_dir(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); } inline bool is_valid_path(const std::string &path) { - size_t level = 0; - size_t i = 0; - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } + size_t level = 0; + size_t i = 0; - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; } - auto len = i - beg; - assert(len > 0); + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { return false; } - level--; - } else { - level++; - } + auto len = i - beg; + assert(len > 0); - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { + return false; + } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } } - } - return true; + return true; } inline std::string encode_url(const std::string &s) { - std::string result; - - for (size_t i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "%20"; break; - case '+': result += "%2B"; break; - case '\r': result += "%0D"; break; - case '\n': result += "%0A"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - // case ':': result += "%3A"; break; // ok? probably... - case ';': result += "%3B"; break; - default: - auto c = static_cast(s[i]); - if (c >= 0x80) { - result += '%'; - char hex[4]; - auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); - assert(len == 2); - result.append(hex, static_cast(len)); - } else { - result += s[i]; - } - break; + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': + result += "%20"; + break; + case '+': + result += "%2B"; + break; + case '\r': + result += "%0D"; + break; + case '\n': + result += "%0A"; + break; + case '\'': + result += "%27"; + break; + case ',': + result += "%2C"; + break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': + result += "%3B"; + break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } } - } - return result; + return result; } inline std::string decode_url(const std::string &s, bool convert_plus_to_space) { - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { result.append(buff, len); } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += static_cast(val); - i += 2; // '00' + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { + result.append(buff, len); + } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; } else { - result += s[i]; + result += s[i]; } - } - } else if (convert_plus_to_space && s[i] == '+') { - result += ' '; - } else { - result += s[i]; } - } - return result; + return result; } inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); } inline std::string file_extension(const std::string &path) { - std::smatch m; - static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, re)) { return m[1].str(); } - return std::string(); + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { + return m[1].str(); + } + return std::string(); } -inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +inline bool is_space_or_tab(char c) { + return c == ' ' || c == '\t'; +} inline std::pair trim(const char *b, const char *e, size_t left, size_t right) { - while (b + left < e && is_space_or_tab(b[left])) { - left++; - } - while (right > 0 && is_space_or_tab(b[right - 1])) { - right--; - } - return std::make_pair(left, right); + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); } inline std::string trim_copy(const std::string &s) { - auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); - return s.substr(r.first, r.second - r.first); + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); } -template void split(const char *b, const char *e, char d, Fn fn) { - size_t i = 0; - size_t beg = 0; +template +void split(const char *b, const char *e, char d, Fn fn) { + size_t i = 0; + size_t beg = 0; - while (e ? (b + i < e) : (b[i] != '\0')) { - if (b[i] == d) { - auto r = trim(b, e, beg, i); - if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } - beg = i + 1; + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + beg = i + 1; + } + i++; } - i++; - } - if (i) { - auto r = trim(b, e, beg, i); - if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } - } + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + } } // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { public: - stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) - : strm_(strm), fixed_buffer_(fixed_buffer), - fixed_buffer_size_(fixed_buffer_size) {} + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) { + } - const char *ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } } - } - size_t size() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_used_size_; - } else { - return glowable_buffer_.size(); + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } } - } - bool end_with_crlf() const { - auto end = ptr() + size(); - return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; - } + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } - for (size_t i = 0;; i++) { - char byte; - auto n = strm_.read(&byte, 1); + append(byte); - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; + if (byte == '\n') { + break; + } } - } - append(byte); - - if (byte == '\n') { break; } + return true; } - return true; - } - private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); - } - glowable_buffer_ += c; - } - } - - Stream &strm_; - char *fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_ = 0; - std::string glowable_buffer_; + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; }; inline int close_socket(socket_t sock) { #ifdef _WIN32 - return closesocket(sock); + return closesocket(sock); #else - return close(sock); + return close(sock); #endif } -template inline ssize_t handle_EINTR(T fn) { - ssize_t res = false; - while (true) { - res = fn(); - if (res < 0 && errno == EINTR) { continue; } - break; - } - return res; +template +inline ssize_t handle_EINTR(T fn) { + ssize_t res = false; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + continue; + } + break; + } + return res; } inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); #else - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); - return handle_EINTR([&]() { - return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); - }); + return handle_EINTR([&]() { + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); + }); #endif } inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLOUT; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); #else - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); - return handle_EINTR([&]() { - return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); - }); + return handle_EINTR([&]() { + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); + }); #endif } inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN | POLLOUT; - - auto timeout = static_cast(sec * 1000 + usec / 1000); - - auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); - - if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { - int error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - return res >= 0 && !error; - } - return false; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + return res >= 0 && !error; + } + return false; #else - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); - }); - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - int error = 0; - socklen_t len = sizeof(error); - return getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len) >= 0 && - !error; - } - return false; + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); + }); + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; #endif } class SocketStream : public Stream { public: - SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec); - ~SocketStream() override; + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec); + ~SocketStream() override; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: - socket_t sock_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { public: - SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec); - ~SSLSocketStream() override; + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: - socket_t sock_; - SSL *ssl_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #endif class BufferStream : public Stream { public: - BufferStream() = default; - ~BufferStream() override = default; + BufferStream() = default; + ~BufferStream() override = default; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; - const std::string &get_buffer() const; + const std::string &get_buffer() const; private: - std::string buffer; - size_t position = 0; + std::string buffer; + size_t position = 0; }; inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { - using namespace std::chrono; - auto start = steady_clock::now(); - while (true) { - auto val = select_read(sock, 0, 10000); - if (val < 0) { - return false; - } else if (val == 0) { - auto current = steady_clock::now(); - auto duration = duration_cast(current - start); - auto timeout = keep_alive_timeout_sec * 1000; - if (duration.count() > timeout) { return false; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } else { - return true; + using namespace std::chrono; + auto start = steady_clock::now(); + while (true) { + auto val = select_read(sock, 0, 10000); + if (val < 0) { + return false; + } else if (val == 0) { + auto current = steady_clock::now(); + auto duration = duration_cast(current - start); + auto timeout = keep_alive_timeout_sec * 1000; + if (duration.count() > timeout) { + return false; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + return true; + } } - } } -template +template inline bool process_server_socket_core(socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, T callback) { - assert(keep_alive_max_count > 0); - auto ret = false; - auto count = keep_alive_max_count; - while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { - auto close_connection = count == 1; - auto connection_closed = false; - ret = callback(close_connection, connection_closed); - if (!ret || connection_closed) { break; } - count--; - } - return ret; -} - -template + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { + break; + } + count--; + } + return ret; +} + +template inline bool process_server_socket(socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); + return process_server_socket_core( + sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); } -template +template inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm); + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); } inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif } -template +template socket_t create_socket(const char *host, int port, int socket_flags, bool tcp_nodelay, SocketOptions socket_options, BindOrConnect bind_or_connect) { - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { + if (getaddrinfo(host, service.c_str(), &hints, &result)) { #ifdef __linux__ - res_init(); + res_init(); #endif - return INVALID_SOCKET; - } + return INVALID_SOCKET; + } - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket #ifdef _WIN32 - auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, - nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); - /** - * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 - * and above the socket creation fails on older Windows Systems. - * - * Let's try to create a socket the old way in this case. - * - * Reference: - * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa - * - * WSA_FLAG_NO_HANDLE_INHERIT: - * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with - * SP1, and later - * - */ - if (sock == INVALID_SOCKET) { - sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - } + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } #else - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); #endif - if (sock == INVALID_SOCKET) { continue; } + if (sock == INVALID_SOCKET) { + continue; + } #ifdef __linux__ - if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + continue; + } #endif - if (tcp_nodelay) { - int yes = 1; - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), - sizeof(yes)); - } + if (tcp_nodelay) { + int yes = 1; + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), + sizeof(yes)); + } - if (socket_options) { socket_options(sock); } + if (socket_options) { + socket_options(sock); + } - if (rp->ai_family == AF_INET6) { - int no = 0; - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), - sizeof(no)); - } + if (rp->ai_family == AF_INET6) { + int no = 0; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), + sizeof(no)); + } - // bind or connect - if (bind_or_connect(sock, *rp)) { - freeaddrinfo(result); - return sock; - } + // bind or connect + if (bind_or_connect(sock, *rp)) { + freeaddrinfo(result); + return sock; + } - close_socket(sock); - } + close_socket(sock); + } - freeaddrinfo(result); - return INVALID_SOCKET; + freeaddrinfo(result); + return INVALID_SOCKET; } inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; - ioctlsocket(sock, FIONBIO, &flags); + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, - nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif } inline bool is_connection_error() { #ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; + return WSAGetLastError() != WSAEWOULDBLOCK; #else - return errno != EINPROGRESS; + return errno != EINPROGRESS; #endif } inline bool bind_ip_address(socket_t sock, const char *host) { - struct addrinfo hints; - struct addrinfo *result; + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; - if (getaddrinfo(host, "0", &hints, &result)) { return false; } + if (getaddrinfo(host, "0", &hints, &result)) { + return false; + } - auto ret = false; - for (auto rp = result; rp; rp = rp->ai_next) { - const auto &ai = *rp; - if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { - ret = true; - break; + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } } - } - freeaddrinfo(result); - return ret; + freeaddrinfo(result); + return ret; } #if !defined _WIN32 && !defined ANDROID @@ -1938,22 +2020,22 @@ inline bool bind_ip_address(socket_t sock, const char *host) { #ifdef USE_IF2IP inline std::string if2ip(const std::string &ifn) { - struct ifaddrs *ifap; - getifaddrs(&ifap); - for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { - if (ifa->ifa_addr && ifn == ifa->ifa_name) { - if (ifa->ifa_addr->sa_family == AF_INET) { - auto sa = reinterpret_cast(ifa->ifa_addr); - char buf[INET_ADDRSTRLEN]; - if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { - freeifaddrs(ifap); - return std::string(buf, INET_ADDRSTRLEN); - } - } - } - } - freeifaddrs(ifap); - return std::string(); + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); + return std::string(); } #endif @@ -1962,1346 +2044,1544 @@ inline socket_t create_client_socket(const char *host, int port, SocketOptions socket_options, time_t timeout_sec, time_t timeout_usec, const std::string &intf, Error &error) { - auto sock = create_socket( - host, port, 0, tcp_nodelay, socket_options, - [&](socket_t sock, struct addrinfo &ai) -> bool { - if (!intf.empty()) { + auto sock = create_socket( + host, port, 0, tcp_nodelay, socket_options, + [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { #ifdef USE_IF2IP - auto ip = if2ip(intf); - if (ip.empty()) { ip = intf; } - if (!bind_ip_address(sock, ip.c_str())) { - error = Error::BindIPAddress; - return false; - } + auto ip = if2ip(intf); + if (ip.empty()) { + ip = intf; + } + if (!bind_ip_address(sock, ip.c_str())) { + error = Error::BindIPAddress; + return false; + } #endif - } + } - set_nonblocking(sock, true); + set_nonblocking(sock, true); - auto ret = - ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + auto ret = + ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); - if (ret < 0) { - if (is_connection_error() || - !wait_until_socket_is_ready(sock, timeout_sec, timeout_usec)) { - close_socket(sock); - error = Error::Connection; - return false; - } - } + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, timeout_usec)) { + close_socket(sock); + error = Error::Connection; + return false; + } + } - set_nonblocking(sock, false); - error = Error::Success; - return true; - }); + set_nonblocking(sock, false); + error = Error::Success; + return true; + }); - if (sock != INVALID_SOCKET) { - error = Error::Success; - } else { - if (error == Error::Success) { error = Error::Connection; } - } + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { + error = Error::Connection; + } + } - return sock; + return sock; } inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { - if (addr.ss_family == AF_INET) { - port = ntohs(reinterpret_cast(&addr)->sin_port); - } else if (addr.ss_family == AF_INET6) { - port = - ntohs(reinterpret_cast(&addr)->sin6_port); - } + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } - std::array ipstr{}; - if (!getnameinfo(reinterpret_cast(&addr), addr_len, - ipstr.data(), static_cast(ipstr.size()), nullptr, - 0, NI_NUMERICHOST)) { - ip = ipstr.data(); - } + std::array ipstr{}; + if (!getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + ip = ipstr.data(); + } } inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); - if (!getpeername(sock, reinterpret_cast(&addr), - &addr_len)) { - get_remote_ip_and_port(addr, addr_len, ip, port); - } + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { + get_remote_ip_and_port(addr, addr_len, ip, port); + } } inline const char * find_content_type(const std::string &path, const std::map &user_data) { - auto ext = file_extension(path); - - auto it = user_data.find(ext); - if (it != user_data.end()) { return it->second.c_str(); } - - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html" || ext == "htm") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "wasm") { - return "application/wasm"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { + return it->second.c_str(); + } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; } inline const char *status_message(int status) { - switch (status) { - case 100: return "Continue"; - case 101: return "Switching Protocol"; - case 102: return "Processing"; - case 103: return "Early Hints"; - case 200: return "OK"; - case 201: return "Created"; - case 202: return "Accepted"; - case 203: return "Non-Authoritative Information"; - case 204: return "No Content"; - case 205: return "Reset Content"; - case 206: return "Partial Content"; - case 207: return "Multi-Status"; - case 208: return "Already Reported"; - case 226: return "IM Used"; - case 300: return "Multiple Choice"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 305: return "Use Proxy"; - case 306: return "unused"; - case 307: return "Temporary Redirect"; - case 308: return "Permanent Redirect"; - case 400: return "Bad Request"; - case 401: return "Unauthorized"; - case 402: return "Payment Required"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 405: return "Method Not Allowed"; - case 406: return "Not Acceptable"; - case 407: return "Proxy Authentication Required"; - case 408: return "Request Timeout"; - case 409: return "Conflict"; - case 410: return "Gone"; - case 411: return "Length Required"; - case 412: return "Precondition Failed"; - case 413: return "Payload Too Large"; - case 414: return "URI Too Long"; - case 415: return "Unsupported Media Type"; - case 416: return "Range Not Satisfiable"; - case 417: return "Expectation Failed"; - case 418: return "I'm a teapot"; - case 421: return "Misdirected Request"; - case 422: return "Unprocessable Entity"; - case 423: return "Locked"; - case 424: return "Failed Dependency"; - case 425: return "Too Early"; - case 426: return "Upgrade Required"; - case 428: return "Precondition Required"; - case 429: return "Too Many Requests"; - case 431: return "Request Header Fields Too Large"; - case 451: return "Unavailable For Legal Reasons"; - case 501: return "Not Implemented"; - case 502: return "Bad Gateway"; - case 503: return "Service Unavailable"; - case 504: return "Gateway Timeout"; - case 505: return "HTTP Version Not Supported"; - case 506: return "Variant Also Negotiates"; - case 507: return "Insufficient Storage"; - case 508: return "Loop Detected"; - case 510: return "Not Extended"; - case 511: return "Network Authentication Required"; - - default: - case 500: return "Internal Server Error"; - } + switch (status) { + case 100: + return "Continue"; + case 101: + return "Switching Protocol"; + case 102: + return "Processing"; + case 103: + return "Early Hints"; + case 200: + return "OK"; + case 201: + return "Created"; + case 202: + return "Accepted"; + case 203: + return "Non-Authoritative Information"; + case 204: + return "No Content"; + case 205: + return "Reset Content"; + case 206: + return "Partial Content"; + case 207: + return "Multi-Status"; + case 208: + return "Already Reported"; + case 226: + return "IM Used"; + case 300: + return "Multiple Choice"; + case 301: + return "Moved Permanently"; + case 302: + return "Found"; + case 303: + return "See Other"; + case 304: + return "Not Modified"; + case 305: + return "Use Proxy"; + case 306: + return "unused"; + case 307: + return "Temporary Redirect"; + case 308: + return "Permanent Redirect"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 402: + return "Payment Required"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 406: + return "Not Acceptable"; + case 407: + return "Proxy Authentication Required"; + case 408: + return "Request Timeout"; + case 409: + return "Conflict"; + case 410: + return "Gone"; + case 411: + return "Length Required"; + case 412: + return "Precondition Failed"; + case 413: + return "Payload Too Large"; + case 414: + return "URI Too Long"; + case 415: + return "Unsupported Media Type"; + case 416: + return "Range Not Satisfiable"; + case 417: + return "Expectation Failed"; + case 418: + return "I'm a teapot"; + case 421: + return "Misdirected Request"; + case 422: + return "Unprocessable Entity"; + case 423: + return "Locked"; + case 424: + return "Failed Dependency"; + case 425: + return "Too Early"; + case 426: + return "Upgrade Required"; + case 428: + return "Precondition Required"; + case 429: + return "Too Many Requests"; + case 431: + return "Request Header Fields Too Large"; + case 451: + return "Unavailable For Legal Reasons"; + case 501: + return "Not Implemented"; + case 502: + return "Bad Gateway"; + case 503: + return "Service Unavailable"; + case 504: + return "Gateway Timeout"; + case 505: + return "HTTP Version Not Supported"; + case 506: + return "Variant Also Negotiates"; + case 507: + return "Insufficient Storage"; + case 508: + return "Loop Detected"; + case 510: + return "Not Extended"; + case 511: + return "Network Authentication Required"; + + default: + case 500: + return "Internal Server Error"; + } } inline bool can_compress_content_type(const std::string &content_type) { - return (!content_type.find("text/") && content_type != "text/event-stream") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; + return (!content_type.find("text/") && content_type != "text/event-stream") || + content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; } -enum class EncodingType { None = 0, Gzip, Brotli }; +enum class EncodingType { None = 0, + Gzip, + Brotli }; inline EncodingType encoding_type(const Request &req, const Response &res) { - auto ret = - detail::can_compress_content_type(res.get_header_value("Content-Type")); - if (!ret) { return EncodingType::None; } + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { + return EncodingType::None; + } - const auto &s = req.get_header_value("Accept-Encoding"); - (void)(s); + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); #ifdef CPPHTTPLIB_BROTLI_SUPPORT - // TODO: 'Accept-Encoding' has br, not br;q=0 - ret = s.find("br") != std::string::npos; - if (ret) { return EncodingType::Brotli; } + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { + return EncodingType::Brotli; + } #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 - ret = s.find("gzip") != std::string::npos; - if (ret) { return EncodingType::Gzip; } + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { + return EncodingType::Gzip; + } #endif - return EncodingType::None; + return EncodingType::None; } class compressor { public: - virtual ~compressor(){}; + virtual ~compressor(){}; - typedef std::function Callback; - virtual bool compress(const char *data, size_t data_length, bool last, - Callback callback) = 0; + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; }; class decompressor { public: - virtual ~decompressor() {} + virtual ~decompressor() { + } - virtual bool is_valid() const = 0; + virtual bool is_valid() const = 0; - typedef std::function Callback; - virtual bool decompress(const char *data, size_t data_length, - Callback callback) = 0; + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; }; class nocompressor : public compressor { public: - ~nocompressor(){}; + ~nocompressor(){}; - bool compress(const char *data, size_t data_length, bool /*last*/, - Callback callback) override { - if (!data_length) { return true; } - return callback(data, data_length); - } + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override { + if (!data_length) { + return true; + } + return callback(data, data_length); + } }; #ifdef CPPHTTPLIB_ZLIB_SUPPORT class gzip_compressor : public compressor { public: - gzip_compressor() { - std::memset(&strm_, 0, sizeof(strm_)); - strm_.zalloc = Z_NULL; - strm_.zfree = Z_NULL; - strm_.opaque = Z_NULL; - - is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, - Z_DEFAULT_STRATEGY) == Z_OK; - } + gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; + } - ~gzip_compressor() { deflateEnd(&strm_); } + ~gzip_compressor() { + deflateEnd(&strm_); + } - bool compress(const char *data, size_t data_length, bool last, - Callback callback) override { - assert(is_valid_); + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override { + assert(is_valid_); - auto flush = last ? Z_FINISH : Z_NO_FLUSH; + auto flush = last ? Z_FINISH : Z_NO_FLUSH; - strm_.avail_in = static_cast(data_length); - strm_.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - int ret = Z_OK; + int ret = Z_OK; - std::array buff{}; - do { - strm_.avail_out = buff.size(); - strm_.next_out = reinterpret_cast(buff.data()); + std::array buff{}; + do { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = deflate(&strm_, flush); - assert(ret != Z_STREAM_ERROR); + ret = deflate(&strm_, flush); + assert(ret != Z_STREAM_ERROR); - if (!callback(buff.data(), buff.size() - strm_.avail_out)) { - return false; - } - } while (strm_.avail_out == 0); + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); - assert((last && ret == Z_STREAM_END) || (!last && ret == Z_OK)); - assert(strm_.avail_in == 0); - return true; - } + assert((last && ret == Z_STREAM_END) || (!last && ret == Z_OK)); + assert(strm_.avail_in == 0); + return true; + } private: - bool is_valid_ = false; - z_stream strm_; + bool is_valid_ = false; + z_stream strm_; }; class gzip_decompressor : public decompressor { public: - gzip_decompressor() { - std::memset(&strm_, 0, sizeof(strm_)); - strm_.zalloc = Z_NULL; - strm_.zfree = Z_NULL; - strm_.opaque = Z_NULL; - - // 15 is the value of wbits, which should be at the maximum possible value - // to ensure that any gzip stream can be decoded. The offset of 32 specifies - // that the stream type should be automatically detected either gzip or - // deflate. - is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; - } + gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; + } - ~gzip_decompressor() { inflateEnd(&strm_); } + ~gzip_decompressor() { + inflateEnd(&strm_); + } - bool is_valid() const override { return is_valid_; } + bool is_valid() const override { + return is_valid_; + } - bool decompress(const char *data, size_t data_length, - Callback callback) override { - assert(is_valid_); + bool decompress(const char *data, size_t data_length, + Callback callback) override { + assert(is_valid_); - int ret = Z_OK; + int ret = Z_OK; - strm_.avail_in = static_cast(data_length); - strm_.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - std::array buff{}; - while (strm_.avail_in > 0) { - strm_.avail_out = buff.size(); - strm_.next_out = reinterpret_cast(buff.data()); + std::array buff{}; + while (strm_.avail_in > 0) { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = inflate(&strm_, Z_NO_FLUSH); - assert(ret != Z_STREAM_ERROR); - switch (ret) { - case Z_NEED_DICT: - case Z_DATA_ERROR: - case Z_MEM_ERROR: inflateEnd(&strm_); return false; - } + ret = inflate(&strm_, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm_); + return false; + } - if (!callback(buff.data(), buff.size() - strm_.avail_out)) { - return false; - } - } + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } - return ret == Z_OK || ret == Z_STREAM_END; - } + return ret == Z_OK || ret == Z_STREAM_END; + } private: - bool is_valid_ = false; - z_stream strm_; + bool is_valid_ = false; + z_stream strm_; }; #endif #ifdef CPPHTTPLIB_BROTLI_SUPPORT class brotli_compressor : public compressor { public: - brotli_compressor() { - state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); - } + brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + } - ~brotli_compressor() { BrotliEncoderDestroyInstance(state_); } + ~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); + } - bool compress(const char *data, size_t data_length, bool last, - Callback callback) override { - std::array buff{}; + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { + break; + } + } else { + if (!available_in) { + break; + } + } - auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; - auto available_in = data_length; - auto next_in = reinterpret_cast(data); + auto available_out = buff.size(); + auto next_out = buff.data(); - for (;;) { - if (last) { - if (BrotliEncoderIsFinished(state_)) { break; } - } else { - if (!available_in) { break; } - } - - auto available_out = buff.size(); - auto next_out = buff.data(); - - if (!BrotliEncoderCompressStream(state_, operation, &available_in, - &next_in, &available_out, &next_out, - nullptr)) { - return false; - } + if (!BrotliEncoderCompressStream(state_, operation, &available_in, + &next_in, &available_out, &next_out, + nullptr)) { + return false; + } - auto output_bytes = buff.size() - available_out; - if (output_bytes) { - callback(reinterpret_cast(buff.data()), output_bytes); - } - } + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } - return true; - } + return true; + } private: - BrotliEncoderState *state_ = nullptr; + BrotliEncoderState *state_ = nullptr; }; class brotli_decompressor : public decompressor { public: - brotli_decompressor() { - decoder_s = BrotliDecoderCreateInstance(0, 0, 0); - decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT - : BROTLI_DECODER_RESULT_ERROR; - } - - ~brotli_decompressor() { - if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } - } + brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT : BROTLI_DECODER_RESULT_ERROR; + } - bool is_valid() const override { return decoder_s; } + ~brotli_decompressor() { + if (decoder_s) { + BrotliDecoderDestroyInstance(decoder_s); + } + } - bool decompress(const char *data, size_t data_length, - Callback callback) override { - if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || - decoder_r == BROTLI_DECODER_RESULT_ERROR) { - return 0; + bool is_valid() const override { + return decoder_s; } - const uint8_t *next_in = (const uint8_t *)data; - size_t avail_in = data_length; - size_t total_out; + bool decompress(const char *data, size_t data_length, + Callback callback) override { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } - decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + const uint8_t *next_in = (const uint8_t *)data; + size_t avail_in = data_length; + size_t total_out; - std::array buff{}; - while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { - char *next_out = buff.data(); - size_t avail_out = buff.size(); + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; - decoder_r = BrotliDecoderDecompressStream( - decoder_s, &avail_in, &next_in, &avail_out, - reinterpret_cast(&next_out), &total_out); + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); - if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); - if (!callback(buff.data(), buff.size() - avail_out)) { return false; } - } + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return false; + } - return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || - decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; - } + if (!callback(buff.data(), buff.size() - avail_out)) { + return false; + } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; + } private: - BrotliDecoderResult decoder_r; - BrotliDecoderState *decoder_s = nullptr; + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; }; #endif inline bool has_header(const Headers &headers, const char *key) { - return headers.find(key) != headers.end(); + return headers.find(key) != headers.end(); } inline const char *get_header_value(const Headers &headers, const char *key, size_t id = 0, const char *def = nullptr) { - auto rng = headers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second.c_str(); } - return def; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second.c_str(); + } + return def; } -template +template inline T get_header_value(const Headers & /*headers*/, const char * /*key*/, - size_t /*id*/ = 0, uint64_t /*def*/ = 0) {} + size_t /*id*/ = 0, uint64_t /*def*/ = 0) { +} -template <> +template<> inline uint64_t get_header_value(const Headers &headers, const char *key, size_t id, uint64_t def) { - auto rng = headers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { - return std::strtoull(it->second.data(), nullptr, 10); - } - return def; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; } -template +template inline bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; - } + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } - auto p = beg; - while (p < end && *p != ':') { - p++; - } + auto p = beg; + while (p < end && *p != ':') { + p++; + } - if (p == end) { return false; } + if (p == end) { + return false; + } - auto key_end = p; + auto key_end = p; - if (*p++ != ':') { return false; } + if (*p++ != ':') { + return false; + } - while (p < end && is_space_or_tab(*p)) { - p++; - } + while (p < end && is_space_or_tab(*p)) { + p++; + } - if (p < end) { - fn(std::string(beg, key_end), decode_url(std::string(p, end), false)); - return true; - } + if (p < end) { + fn(std::string(beg, key_end), decode_url(std::string(p, end), false)); + return true; + } - return false; + return false; } inline bool read_headers(Stream &strm, Headers &headers) { - const auto bufsiz = 2048; - char buf[bufsiz]; - stream_line_reader line_reader(strm, buf, bufsiz); + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); - for (;;) { - if (!line_reader.getline()) { return false; } + for (;;) { + if (!line_reader.getline()) { + return false; + } - // Check if the line ends with CRLF. - if (line_reader.end_with_crlf()) { - // Blank line indicates end of headers. - if (line_reader.size() == 2) { break; } - } else { - continue; // Skip invalid line. - } + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { + break; + } + } else { + continue; // Skip invalid line. + } - // Exclude CRLF - auto end = line_reader.ptr() + line_reader.size() - 2; + // Exclude CRLF + auto end = line_reader.ptr() + line_reader.size() - 2; - parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - headers.emplace(std::move(key), std::move(val)); - }); - } + parse_header(line_reader.ptr(), end, + [&](std::string &&key, std::string &&val) { + headers.emplace(std::move(key), std::move(val)); + }); + } - return true; + return true; } inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiver out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return false; } + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return false; + } - if (!out(buf, static_cast(n))) { return false; } + if (!out(buf, static_cast(n))) { + return false; + } - r += static_cast(n); + r += static_cast(n); - if (progress) { - if (!progress(r, len)) { return false; } + if (progress) { + if (!progress(r, len)) { + return false; + } + } } - } - return true; + return true; } inline void skip_content_with_length(Stream &strm, uint64_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += static_cast(n); - } + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return; + } + r += static_cast(n); + } } inline bool read_content_without_length(Stream &strm, ContentReceiver out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - for (;;) { - auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n < 0) { - return false; - } else if (n == 0) { - return true; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, static_cast(n))) { + return false; + } } - if (!out(buf, static_cast(n))) { return false; } - } - return true; + return true; } inline bool read_content_chunked(Stream &strm, ContentReceiver out) { - const auto bufsiz = 16; - char buf[bufsiz]; + const auto bufsiz = 16; + char buf[bufsiz]; - stream_line_reader line_reader(strm, buf, bufsiz); + stream_line_reader line_reader(strm, buf, bufsiz); - if (!line_reader.getline()) { return false; } + if (!line_reader.getline()) { + return false; + } - unsigned long chunk_len; - while (true) { - char *end_ptr; + unsigned long chunk_len; + while (true) { + char *end_ptr; - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); - if (end_ptr == line_reader.ptr()) { return false; } - if (chunk_len == ULONG_MAX) { return false; } + if (end_ptr == line_reader.ptr()) { + return false; + } + if (chunk_len == ULONG_MAX) { + return false; + } - if (chunk_len == 0) { break; } + if (chunk_len == 0) { + break; + } - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { - return false; - } + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } - if (!line_reader.getline()) { return false; } + if (!line_reader.getline()) { + return false; + } - if (strcmp(line_reader.ptr(), "\r\n")) { break; } + if (strcmp(line_reader.ptr(), "\r\n")) { + break; + } - if (!line_reader.getline()) { return false; } - } + if (!line_reader.getline()) { + return false; + } + } - if (chunk_len == 0) { - // Reader terminator after chunks - if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) - return false; - } + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } - return true; + return true; } inline bool is_chunked_transfer_encoding(const Headers &headers) { - return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), - "chunked"); + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); } -template +template bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver, bool decompress, U callback) { - if (decompress) { - std::string encoding = x.get_header_value("Content-Encoding"); - std::shared_ptr decompressor; + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::shared_ptr decompressor; - if (encoding.find("gzip") != std::string::npos || - encoding.find("deflate") != std::string::npos) { + if (encoding.find("gzip") != std::string::npos || + encoding.find("deflate") != std::string::npos) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = std::make_shared(); + decompressor = std::make_shared(); #else - status = 415; - return false; + status = 415; + return false; #endif - } else if (encoding.find("br") != std::string::npos) { + } else if (encoding.find("br") != std::string::npos) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = std::make_shared(); + decompressor = std::make_shared(); #else - status = 415; - return false; + status = 415; + return false; #endif - } + } - if (decompressor) { - if (decompressor->is_valid()) { - ContentReceiver out = [&](const char *buf, size_t n) { - return decompressor->decompress( - buf, n, - [&](const char *buf, size_t n) { return receiver(buf, n); }); - }; - return callback(out); - } else { - status = 500; - return false; - } + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiver out = [&](const char *buf, size_t n) { + return decompressor->decompress( + buf, n, + [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + return callback(out); + } else { + status = 500; + return false; + } + } } - } - ContentReceiver out = [&](const char *buf, size_t n) { - return receiver(buf, n); - }; - return callback(out); + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + return callback(out); } -template +template bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, Progress progress, ContentReceiver receiver, bool decompress) { - return prepare_content_receiver( - x, status, receiver, decompress, [&](const ContentReceiver &out) { - auto ret = true; - auto exceed_payload_max_length = false; - - if (is_chunked_transfer_encoding(x.headers)) { - ret = read_content_chunked(strm, out); - } else if (!has_header(x.headers, "Content-Length")) { - ret = read_content_without_length(strm, out); - } else { - auto len = get_header_value(x.headers, "Content-Length"); - if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; - } else if (len > 0) { - ret = read_content_with_length(strm, len, progress, out); - } - } + return prepare_content_receiver( + x, status, receiver, decompress, [&](const ContentReceiver &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value(x.headers, "Content-Length"); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } - if (!ret) { status = exceed_payload_max_length ? 413 : 400; } - return ret; - }); + if (!ret) { + status = exceed_payload_max_length ? 413 : 400; + } + return ret; + }); } -template +template inline ssize_t write_headers(Stream &strm, const T &info, const Headers &headers) { - ssize_t write_len = 0; - for (const auto &x : info.headers) { - if (x.first == "EXCEPTION_WHAT") { continue; } - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } - write_len += len; - } - for (const auto &x : headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } + ssize_t write_len = 0; + for (const auto &x : info.headers) { + if (x.first == "EXCEPTION_WHAT") { + continue; + } + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } write_len += len; - } - auto len = strm.write("\r\n"); - if (len < 0) { return len; } - write_len += len; - return write_len; + return write_len; } inline bool write_data(Stream &strm, const char *d, size_t l) { - size_t offset = 0; - while (offset < l) { - auto length = strm.write(d + offset, l - offset); - if (length < 0) { return false; } - offset += static_cast(length); - } - return true; + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { + return false; + } + offset += static_cast(length); + } + return true; } -template +template inline ssize_t write_content(Stream &strm, ContentProvider content_provider, size_t offset, size_t length, T is_shutting_down) { - size_t begin_offset = offset; - size_t end_offset = offset + length; - auto ok = true; - DataSink data_sink; + size_t begin_offset = offset; + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - offset += l; - if (!write_data(strm, d, l)) { ok = false; } - } - }; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { + ok = false; + } + } + }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (offset < end_offset && !is_shutting_down()) { - if (!content_provider(offset, end_offset - offset, data_sink)) { - return -1; + while (offset < end_offset && !is_shutting_down()) { + if (!content_provider(offset, end_offset - offset, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } } - if (!ok) { return -1; } - } - return static_cast(offset - begin_offset); + return static_cast(offset - begin_offset); } -template +template inline ssize_t write_content_without_length(Stream &strm, ContentProvider content_provider, T is_shutting_down) { - size_t offset = 0; - auto data_available = true; - auto ok = true; - DataSink data_sink; + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - offset += l; - if (!write_data(strm, d, l)) { ok = false; } - } - }; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { + ok = false; + } + } + }; - data_sink.done = [&](void) { data_available = false; }; + data_sink.done = [&](void) { data_available = false; }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (data_available && !is_shutting_down()) { - if (!content_provider(offset, 0, data_sink)) { return -1; } - if (!ok) { return -1; } - } + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } + } - return static_cast(offset); + return static_cast(offset); } -template +template inline ssize_t write_content_chunked(Stream &strm, ContentProvider content_provider, T is_shutting_down, U &compressor) { - size_t offset = 0; - auto data_available = true; - ssize_t total_written_length = 0; - auto ok = true; - DataSink data_sink; - - data_sink.write = [&](const char *d, size_t l) { - if (!ok) { return; } - - data_available = l > 0; - offset += l; - - std::string payload; - if (!compressor.compress(d, l, false, - [&](const char *data, size_t data_len) { - payload.append(data, data_len); - return true; - })) { - ok = false; - return; - } - - if (!payload.empty()) { - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (write_data(strm, chunk.data(), chunk.size())) { - total_written_length += chunk.size(); - } else { - ok = false; - return; - } - } - }; - - data_sink.done = [&](void) { - if (!ok) { return; } - - data_available = false; - - std::string payload; - if (!compressor.compress(nullptr, 0, true, - [&](const char *data, size_t data_len) { - payload.append(data, data_len); - return true; - })) { - ok = false; - return; - } - - if (!payload.empty()) { - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (write_data(strm, chunk.data(), chunk.size())) { - total_written_length += chunk.size(); - } else { - ok = false; - return; - } - } - - static const std::string done_marker("0\r\n\r\n"); - if (write_data(strm, done_marker.data(), done_marker.size())) { - total_written_length += done_marker.size(); - } else { - ok = false; - } - }; + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + auto ok = true; + DataSink data_sink; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.write = [&](const char *d, size_t l) { + if (!ok) { + return; + } + + data_available = l > 0; + offset += l; + + std::string payload; + if (!compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (write_data(strm, chunk.data(), chunk.size())) { + total_written_length += chunk.size(); + } else { + ok = false; + return; + } + } + }; - while (data_available && !is_shutting_down()) { - if (!content_provider(offset, 0, data_sink)) { return -1; } - if (!ok) { return -1; } - } + data_sink.done = [&](void) { + if (!ok) { + return; + } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (write_data(strm, chunk.data(), chunk.size())) { + total_written_length += chunk.size(); + } else { + ok = false; + return; + } + } - return total_written_length; + static const std::string done_marker("0\r\n\r\n"); + if (write_data(strm, done_marker.data(), done_marker.size())) { + total_written_length += done_marker.size(); + } else { + ok = false; + } + }; + + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } + } + + return total_written_length; } -template +template inline bool redirect(T &cli, const Request &req, Response &res, const std::string &path) { - Request new_req = req; - new_req.path = path; - new_req.redirect_count -= 1; - - if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { - new_req.method = "GET"; - new_req.body.clear(); - new_req.headers.clear(); - } + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } - Response new_res; + Response new_res; - auto ret = cli.send(new_req, new_res); - if (ret) { res = new_res; } - return ret; + auto ret = cli.send(new_req, new_res); + if (ret) { + res = new_res; + } + return ret; } inline std::string params_to_query_str(const Params ¶ms) { - std::string query; + std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } - query += it->first; - query += "="; - query += encode_url(it->second); - } - return query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { + query += "&"; + } + query += it->first; + query += "="; + query += encode_url(it->second); + } + return query; } inline void parse_query_text(const std::string &s, Params ¶ms) { - split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char *b2, const char *e2) { - if (key.empty()) { - key.assign(b2, e2); - } else { - val.assign(b2, e2); - } - }); + split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); - if (!key.empty()) { - params.emplace(decode_url(key, true), decode_url(val, true)); - } - }); + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); } inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { return false; } - boundary = content_type.substr(pos + 9); - if (boundary.length() >= 2 && boundary.front() == '"' && - boundary.back() == '"') { - boundary = boundary.substr(1, boundary.size() - 2); - } - return !boundary.empty(); + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { + return false; + } + boundary = content_type.substr(pos + 9); + if (boundary.length() >= 2 && boundary.front() == '"' && + boundary.back() == '"') { + boundary = boundary.substr(1, boundary.size() - 2); + } + return !boundary.empty(); } inline bool parse_range_header(const std::string &s, Ranges &ranges) { - static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); - std::smatch m; - if (std::regex_match(s, m, re_first_range)) { - auto pos = static_cast(m.position(1)); - auto len = static_cast(m.length(1)); - bool all_valid_ranges = true; - split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { - if (!all_valid_ranges) return; - static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch cm; - if (std::regex_match(b, e, cm, re_another_range)) { - ssize_t first = -1; - if (!cm.str(1).empty()) { - first = static_cast(std::stoll(cm.str(1))); - } - - ssize_t last = -1; - if (!cm.str(2).empty()) { - last = static_cast(std::stoll(cm.str(2))); - } - - if (first != -1 && last != -1 && first > last) { - all_valid_ranges = false; - return; - } - ranges.emplace_back(std::make_pair(first, last)); - } - }); - return all_valid_ranges; - } - return false; + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; } class MultipartFormDataParser { public: - MultipartFormDataParser() = default; - - void set_boundary(std::string &&boundary) { boundary_ = boundary; } - - bool is_valid() const { return is_valid_; } - - template - bool parse(const char *buf, size_t n, T content_callback, U header_callback) { - - static const std::regex re_content_disposition( - "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" - "\"(.*?)\")?\\s*$", - std::regex_constants::icase); - static const std::string dash_ = "--"; - static const std::string crlf_ = "\r\n"; - - buf_.append(buf, n); // TODO: performance improvement - - while (!buf_.empty()) { - switch (state_) { - case 0: { // Initial boundary - auto pattern = dash_ + boundary_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - auto pos = buf_.find(pattern); - if (pos != 0) { return false; } - buf_.erase(0, pattern.size()); - off_ += pattern.size(); - state_ = 1; - break; - } - case 1: { // New entry - clear_file_info(); - state_ = 2; - break; - } - case 2: { // Headers - auto pos = buf_.find(crlf_); - while (pos != std::string::npos) { - // Empty line - if (pos == 0) { - if (!header_callback(file_)) { - is_valid_ = false; - return false; - } - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); - state_ = 3; - break; - } - - static const std::string header_name = "content-type:"; - const auto header = buf_.substr(0, pos); - if (start_with(header, header_name)) { - file_.content_type = trim_copy(header.substr(header_name.size())); - } else { - std::smatch m; - if (std::regex_match(header, m, re_content_disposition)) { - file_.name = m[1]; - file_.filename = m[2]; - } - } - - buf_.erase(0, pos + crlf_.size()); - off_ += pos + crlf_.size(); - pos = buf_.find(crlf_); - } - if (state_ != 3) { return true; } - break; - } - case 3: { // Body - { - auto pattern = crlf_ + dash_; - if (pattern.size() > buf_.size()) { return true; } - - auto pos = buf_.find(pattern); - if (pos == std::string::npos) { - pos = buf_.size(); - while (pos > 0) { - auto c = buf_[pos - 1]; - if (c != '\r' && c != '\n' && c != '-') { break; } - pos--; - } - } + MultipartFormDataParser() = default; - if (!content_callback(buf_.data(), pos)) { - is_valid_ = false; - return false; - } + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + } - off_ += pos; - buf_.erase(0, pos); - } + bool is_valid() const { + return is_valid_; + } - { - auto pattern = crlf_ + dash_ + boundary_; - if (pattern.size() > buf_.size()) { return true; } - - auto pos = buf_.find(pattern); - if (pos != std::string::npos) { - if (!content_callback(buf_.data(), pos)) { - is_valid_ = false; - return false; + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + static const std::string dash_ = "--"; + static const std::string crlf_ = "\r\n"; + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { + return true; + } + auto pos = buf_.find(pattern); + if (pos != 0) { + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + static const std::string header_name = "content-type:"; + const auto header = buf_.substr(0, pos); + if (start_with(header, header_name)) { + file_.content_type = trim_copy(header.substr(header_name.size())); + } else { + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + if (state_ != 3) { + return true; + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { + return true; + } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { + pos = buf_.size(); + while (pos > 0) { + auto c = buf_[pos - 1]; + if (c != '\r' && c != '\n' && c != '-') { + break; + } + pos--; + } + } + + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { + return true; + } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { + return true; + } + if (buf_.compare(0, crlf_.size(), crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { + return true; + } + if (buf_.compare(0, pattern.size(), pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; } - - off_ += pos + pattern.size(); - buf_.erase(0, pos + pattern.size()); - state_ = 4; - } else { - if (!content_callback(buf_.data(), pattern.size())) { - is_valid_ = false; - return false; } - - off_ += pattern.size(); - buf_.erase(0, pattern.size()); - } - } - break; - } - case 4: { // Boundary - if (crlf_.size() > buf_.size()) { return true; } - if (buf_.compare(0, crlf_.size(), crlf_) == 0) { - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); - state_ = 1; - } else { - auto pattern = dash_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - if (buf_.compare(0, pattern.size(), pattern) == 0) { - buf_.erase(0, pattern.size()); - off_ += pattern.size(); - is_valid_ = true; - state_ = 5; - } else { - return true; - } } - break; - } - case 5: { // Done - is_valid_ = false; - return false; - } - } - } - return true; - } + return true; + } private: - void clear_file_info() { - file_.name.clear(); - file_.filename.clear(); - file_.content_type.clear(); - } - - std::string boundary_; - - std::string buf_; - size_t state_ = 0; - bool is_valid_ = false; - size_t off_ = 0; - MultipartFormData file_; + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + bool is_valid_ = false; + size_t off_ = 0; + MultipartFormData file_; }; inline std::string to_lower(const char *beg, const char *end) { - std::string out; - auto it = beg; - while (it != end) { - out += static_cast(::tolower(*it)); - it++; - } - return out; + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; } inline std::string make_multipart_data_boundary() { - static const char data[] = - "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - std::random_device seed_gen; - std::mt19937 engine(seed_gen()); + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); - std::string result = "--cpp-httplib-multipart-data-"; + std::string result = "--cpp-httplib-multipart-data-"; - for (auto i = 0; i < 16; i++) { - result += data[engine() % (sizeof(data) - 1)]; - } + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } - return result; + return result; } inline std::pair get_range_offset_and_length(const Request &req, size_t content_length, size_t index) { - auto r = req.ranges[index]; + auto r = req.ranges[index]; - if (r.first == -1 && r.second == -1) { - return std::make_pair(0, content_length); - } + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } - auto slen = static_cast(content_length); + auto slen = static_cast(content_length); - if (r.first == -1) { - r.first = slen - r.second; - r.second = slen - 1; - } + if (r.first == -1) { + r.first = slen - r.second; + r.second = slen - 1; + } - if (r.second == -1) { r.second = slen - 1; } + if (r.second == -1) { + r.second = slen - 1; + } - return std::make_pair(r.first, r.second - r.first + 1); + return std::make_pair(r.first, r.second - r.first + 1); } inline std::string make_content_range_header_field(size_t offset, size_t length, size_t content_length) { - std::string field = "bytes "; - field += std::to_string(offset); - field += "-"; - field += std::to_string(offset + length - 1); - field += "/"; - field += std::to_string(content_length); - return field; + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; } -template +template bool process_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, const std::string &content_type, SToken stoken, CToken ctoken, Content content) { - for (size_t i = 0; i < req.ranges.size(); i++) { - ctoken("--"); - stoken(boundary); - ctoken("\r\n"); - if (!content_type.empty()) { - ctoken("Content-Type: "); - stoken(content_type); - ctoken("\r\n"); - } + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } - auto offsets = get_range_offset_and_length(req, res.body.size(), i); - auto offset = offsets.first; - auto length = offsets.second; + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; - ctoken("Content-Range: "); - stoken(make_content_range_header_field(offset, length, res.body.size())); - ctoken("\r\n"); - ctoken("\r\n"); - if (!content(offset, length)) { return false; } - ctoken("\r\n"); - } + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { + return false; + } + ctoken("\r\n"); + } - ctoken("--"); - stoken(boundary); - ctoken("--\r\n"); + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); - return true; + return true; } inline std::string make_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - std::string data; - - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data += token; }, - [&](const char *token) { data += token; }, - [&](size_t offset, size_t length) { - data += res.body.substr(offset, length); - return true; - }); + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); - return data; + return data; } inline size_t get_multipart_ranges_data_length(const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - size_t data_length = 0; - - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data_length += token.size(); }, - [&](const char *token) { data_length += strlen(token); }, - [&](size_t /*offset*/, size_t length) { - data_length += length; - return true; - }); + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); - return data_length; + return data_length; } -template +template inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type, T is_shutting_down) { - return process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { strm.write(token); }, - [&](const char *token) { strm.write(token); }, - [&](size_t offset, size_t length) { - return write_content(strm, res.content_provider_, offset, length, - is_shutting_down) >= 0; - }); + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down) >= 0; + }); } inline std::pair get_range_offset_and_length(const Request &req, const Response &res, size_t index) { - auto r = req.ranges[index]; + auto r = req.ranges[index]; - if (r.second == -1) { - r.second = static_cast(res.content_length_) - 1; - } + if (r.second == -1) { + r.second = static_cast(res.content_length_) - 1; + } - return std::make_pair(r.first, r.second - r.first + 1); + return std::make_pair(r.first, r.second - r.first + 1); } inline bool expect_content(const Request &req) { - if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || - req.method == "PRI" || req.method == "DELETE") { - return true; - } - // TODO: check if Content-Length is set - return false; + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; } inline bool has_crlf(const char *s) { - auto p = s; - while (*p) { - if (*p == '\r' || *p == '\n') { return true; } - p++; - } - return false; + auto p = s; + while (*p) { + if (*p == '\r' || *p == '\n') { + return true; + } + p++; + } + return false; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -template +template inline std::string message_digest(const std::string &s, Init init, Update update, Final final, size_t digest_length) { - using namespace std; + using namespace std; - std::vector md(digest_length, 0); - CTX ctx; - init(&ctx); - update(&ctx, s.data(), s.size()); - final(md.data(), &ctx); + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); - stringstream ss; - for (auto c : md) { - ss << setfill('0') << setw(2) << hex << (unsigned int)c; - } - return ss.str(); + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); } inline std::string MD5(const std::string &s) { - return message_digest(s, MD5_Init, MD5_Update, MD5_Final, - MD5_DIGEST_LENGTH); + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); } inline std::string SHA_256(const std::string &s) { - return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, - SHA256_DIGEST_LENGTH); + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); } inline std::string SHA_512(const std::string &s) { - return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, - SHA512_DIGEST_LENGTH); + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); } #endif @@ -3310,37 +3590,41 @@ inline std::string SHA_512(const std::string &s) { // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store inline bool load_system_certs_on_windows(X509_STORE *store) { - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - if (!hStore) { return false; } + if (!hStore) { + return false; + } - PCCERT_CONTEXT pContext = NULL; - while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); + PCCERT_CONTEXT pContext = NULL; + while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + } } - } - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); - return true; + return true; } #endif class WSInit { public: - WSInit() { - WSADATA wsaData; - WSAStartup(0x0002, &wsaData); - } + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } - ~WSInit() { WSACleanup(); } + ~WSInit() { + WSACleanup(); + } }; static WSInit wsinit_; @@ -3351,340 +3635,355 @@ inline std::pair make_digest_authentication_header( const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, const std::string &username, const std::string &password, bool is_proxy = false) { - using namespace std; + using namespace std; - string nc; - { - stringstream ss; - ss << setfill('0') << setw(8) << hex << cnonce_count; - nc = ss.str(); - } + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } - auto qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else { - qop = "auth"; - } + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { + algo = auth.at("algorithm"); + } - string response; - { - auto H = algo == "SHA-256" - ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : + detail::MD5; - auto A1 = username + ":" + auth.at("realm") + ":" + password; + auto A1 = username + ":" + auth.at("realm") + ":" + password; - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + - "\", response=\"" + response + "\""; + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + + "\", response=\"" + response + "\""; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } #endif inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { - auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; - if (res.has_header(auth_key)) { - static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); - auto s = res.get_header_value(auth_key); - auto pos = s.find(' '); - if (pos != std::string::npos) { - auto type = s.substr(0, pos); - if (type == "Basic") { - return false; - } else if (type == "Digest") { - s = s.substr(pos + 1); - auto beg = std::sregex_iterator(s.begin(), s.end(), re); - for (auto i = beg; i != std::sregex_iterator(); ++i) { - auto m = *i; - auto key = s.substr(static_cast(m.position(1)), - static_cast(m.length(1))); - auto val = m.length(2) > 0 - ? s.substr(static_cast(m.position(2)), - static_cast(m.length(2))) - : s.substr(static_cast(m.position(3)), - static_cast(m.length(3))); - auth[key] = val; + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) : + s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } } - return true; - } } - } - return false; + return false; } // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 inline std::string random_string(size_t length) { - auto randchar = []() -> char { - const char charset[] = "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = (sizeof(charset) - 1); - return charset[static_cast(rand()) % max_index]; - }; - std::string str(length, 0); - std::generate_n(str.begin(), length, randchar); - return str; + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; } class ContentProviderAdapter { public: - explicit ContentProviderAdapter( - ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) { + } - bool operator()(size_t offset, size_t, DataSink &sink) { - return content_provider_(offset, sink); - } + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } private: - ContentProviderWithoutLength content_provider_; + ContentProviderWithoutLength content_provider_; }; -} // namespace detail +} // namespace detail // Header utilities inline std::pair make_range_header(Ranges ranges) { - std::string field = "bytes="; - auto i = 0; - for (auto r : ranges) { - if (i != 0) { field += ", "; } - if (r.first != -1) { field += std::to_string(r.first); } - field += '-'; - if (r.second != -1) { field += std::to_string(r.second); } - i++; - } - return std::make_pair("Range", field); + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { + field += ", "; + } + if (r.first != -1) { + field += std::to_string(r.first); + } + field += '-'; + if (r.second != -1) { + field += std::to_string(r.second); + } + i++; + } + return std::make_pair("Range", field); } inline std::pair make_basic_authentication_header(const std::string &username, const std::string &password, bool is_proxy = false) { - auto field = "Basic " + detail::base64_encode(username + ":" + password); - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } inline std::pair make_bearer_token_authentication_header(const std::string &token, bool is_proxy = false) { - auto field = "Bearer " + token; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } // Request implementation inline bool Request::has_header(const char *key) const { - return detail::has_header(headers, key); + return detail::has_header(headers, key); } inline std::string Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, id, ""); } -template +template inline T Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, 0); + return detail::get_header_value(headers, key, id, 0); } inline size_t Request::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline void Request::set_header(const char *key, const char *val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } } inline void Request::set_header(const char *key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } } inline bool Request::has_param(const char *key) const { - return params.find(key) != params.end(); + return params.find(key) != params.end(); } inline std::string Request::get_param_value(const char *key, size_t id) const { - auto rng = params.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second; + } + return std::string(); } inline size_t Request::get_param_value_count(const char *key) const { - auto r = params.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline bool Request::is_multipart_form_data() const { - const auto &content_type = get_header_value("Content-Type"); - return !content_type.find("multipart/form-data"); + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); } inline bool Request::has_file(const char *key) const { - return files.find(key) != files.end(); + return files.find(key) != files.end(); } inline MultipartFormData Request::get_file_value(const char *key) const { - auto it = files.find(key); - if (it != files.end()) { return it->second; } - return MultipartFormData(); + auto it = files.find(key); + if (it != files.end()) { + return it->second; + } + return MultipartFormData(); } // Response implementation inline bool Response::has_header(const char *key) const { - return headers.find(key) != headers.end(); + return headers.find(key) != headers.end(); } inline std::string Response::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, id, ""); } -template +template inline T Response::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, 0); + return detail::get_header_value(headers, key, id, 0); } inline size_t Response::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline void Response::set_header(const char *key, const char *val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } } inline void Response::set_header(const char *key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } } inline void Response::set_redirect(const char *url, int stat) { - if (!detail::has_crlf(url)) { - set_header("Location", url); - if (300 <= stat && stat < 400) { - this->status = stat; - } else { - this->status = 302; + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = 302; + } } - } } inline void Response::set_redirect(const std::string &url, int stat) { - set_redirect(url.c_str(), stat); + set_redirect(url.c_str(), stat); } inline void Response::set_content(const char *s, size_t n, const char *content_type) { - body.assign(s, n); - set_header("Content-Type", content_type); + body.assign(s, n); + set_header("Content-Type", content_type); } inline void Response::set_content(std::string s, const char *content_type) { - body = std::move(s); - set_header("Content-Type", content_type); + body = std::move(s); + set_header("Content-Type", content_type); } inline void Response::set_content_provider(size_t in_length, const char *content_type, ContentProvider provider, const std::function &resource_releaser) { - assert(in_length > 0); - set_header("Content-Type", content_type); - content_length_ = in_length; - content_provider_ = std::move(provider); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = false; + assert(in_length > 0); + set_header("Content-Type", content_type); + content_length_ = in_length; + content_provider_ = std::move(provider); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } inline void Response::set_content_provider(const char *content_type, ContentProviderWithoutLength provider, const std::function &resource_releaser) { - set_header("Content-Type", content_type); - content_length_ = 0; - content_provider_ = detail::ContentProviderAdapter(std::move(provider)); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = false; + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } inline void Response::set_chunked_content_provider( const char *content_type, ContentProviderWithoutLength provider, const std::function &resource_releaser) { - set_header("Content-Type", content_type); - content_length_ = 0; - content_provider_ = detail::ContentProviderAdapter(std::move(provider)); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = true; + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = true; } // Rstream implementation inline ssize_t Stream::write(const char *ptr) { - return write(ptr, strlen(ptr)); + return write(ptr, strlen(ptr)); } inline ssize_t Stream::write(const std::string &s) { - return write(s.data(), s.size()); + return write(s.data(), s.size()); } -template -inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { - const auto bufsiz = 2048; - std::array buf; +template +inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { + const auto bufsiz = 2048; + std::array buf; #if defined(_MSC_VER) && _MSC_VER < 1900 - auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); + auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); #else - auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); #endif - if (sn <= 0) { return sn; } + if (sn <= 0) { + return sn; + } - auto n = static_cast(sn); + auto n = static_cast(sn); - if (n >= buf.size() - 1) { - std::vector glowable_buf(buf.size()); + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); - while (n >= glowable_buf.size() - 1) { - glowable_buf.resize(glowable_buf.size() * 2); + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); #if defined(_MSC_VER) && _MSC_VER < 1900 - n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), - glowable_buf.size() - 1, fmt, - args...)); + n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, + args...)); #else - n = static_cast( - snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); #endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); } - return write(&glowable_buf[0], n); - } else { - return write(buf.data(), n); - } } namespace detail { @@ -3697,75 +3996,88 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, : sock_(sock), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec) {} + write_timeout_usec_(write_timeout_usec) { +} -inline SocketStream::~SocketStream() {} +inline SocketStream::~SocketStream() { +} inline bool SocketStream::is_readable() const { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SocketStream::is_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; } inline ssize_t SocketStream::read(char *ptr, size_t size) { - if (!is_readable()) { return -1; } + if (!is_readable()) { + return -1; + } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return recv(sock_, ptr, static_cast(size), 0); + if (size > static_cast((std::numeric_limits::max)())) { + return -1; + } + return recv(sock_, ptr, static_cast(size), 0); #else - return handle_EINTR([&]() { return recv(sock_, ptr, size, 0); }); + return handle_EINTR([&]() { return recv(sock_, ptr, size, 0); }); #endif } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!is_writable()) { + return -1; + } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return send(sock_, ptr, static_cast(size), 0); + if (size > static_cast((std::numeric_limits::max)())) { + return -1; + } + return send(sock_, ptr, static_cast(size), 0); #else - return handle_EINTR([&]() { return send(sock_, ptr, size, 0); }); + return handle_EINTR([&]() { return send(sock_, ptr, size, 0); }); #endif } inline void SocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { - return detail::get_remote_ip_and_port(sock_, ip, port); + return detail::get_remote_ip_and_port(sock_, ip, port); } // Buffer stream implementation -inline bool BufferStream::is_readable() const { return true; } +inline bool BufferStream::is_readable() const { + return true; +} -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::is_writable() const { + return true; +} inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER <= 1900 - auto len_read = buffer._Copy_s(ptr, size, size, position); + auto len_read = buffer._Copy_s(ptr, size, size, position); #else - auto len_read = buffer.copy(ptr, size, position); + auto len_read = buffer.copy(ptr, size, position); #endif - position += static_cast(len_read); - return static_cast(len_read); + position += static_cast(len_read); + return static_cast(len_read); } inline ssize_t BufferStream::write(const char *ptr, size_t size) { - buffer.append(ptr, size); - return static_cast(size); + buffer.append(ptr, size); + return static_cast(size); } inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, - int & /*port*/) const {} + int & /*port*/) const { +} -inline const std::string &BufferStream::get_buffer() const { return buffer; } +inline const std::string &BufferStream::get_buffer() const { + return buffer; +} -} // namespace detail +} // namespace detail // HTTP server implementation inline Server::Server() @@ -3773,1182 +4085,1279 @@ inline Server::Server() [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }), svr_sock_(INVALID_SOCKET), is_running_(false) { #ifdef __linux__ - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif } -inline Server::~Server() {} +inline Server::~Server() { +} inline Server &Server::Get(const char *pattern, Handler handler) { - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Post(const char *pattern, Handler handler) { - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Post(const char *pattern, HandlerWithContentReader handler) { - post_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Put(const char *pattern, Handler handler) { - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Put(const char *pattern, HandlerWithContentReader handler) { - put_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Patch(const char *pattern, Handler handler) { - patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Patch(const char *pattern, HandlerWithContentReader handler) { - patch_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Delete(const char *pattern, Handler handler) { - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Delete(const char *pattern, HandlerWithContentReader handler) { - delete_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + delete_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Options(const char *pattern, Handler handler) { - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline bool Server::set_base_dir(const char *dir, const char *mount_point) { - return set_mount_point(mount_point, dir); + return set_mount_point(mount_point, dir); } inline bool Server::set_mount_point(const char *mount_point, const char *dir) { - if (detail::is_dir(dir)) { - std::string mnt = mount_point ? mount_point : "/"; - if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.emplace_back(mnt, dir); - return true; + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } } - } - return false; + return false; } inline bool Server::remove_mount_point(const char *mount_point) { - for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { - if (it->first == mount_point) { - base_dirs_.erase(it); - return true; + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } } - } - return false; + return false; } inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, const char *mime) { - file_extension_and_mimetype_map_[ext] = mime; + file_extension_and_mimetype_map_[ext] = mime; } inline void Server::set_file_request_handler(Handler handler) { - file_request_handler_ = std::move(handler); + file_request_handler_ = std::move(handler); } inline void Server::set_error_handler(Handler handler) { - error_handler_ = std::move(handler); + error_handler_ = std::move(handler); } -inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} inline void Server::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = socket_options; } -inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } +inline void Server::set_logger(Logger logger) { + logger_ = std::move(logger); +} inline void Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { - expect_100_continue_handler_ = std::move(handler); + expect_100_continue_handler_ = std::move(handler); } inline void Server::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; + keep_alive_max_count_ = count; } inline void Server::set_keep_alive_timeout(time_t sec) { - keep_alive_timeout_sec_ = sec; + keep_alive_timeout_sec_ = sec; } inline void Server::set_read_timeout(time_t sec, time_t usec) { - read_timeout_sec_ = sec; - read_timeout_usec_ = usec; + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } inline void Server::set_write_timeout(time_t sec, time_t usec) { - write_timeout_sec_ = sec; - write_timeout_usec_ = usec; + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; } inline void Server::set_idle_interval(time_t sec, time_t usec) { - idle_interval_sec_ = sec; - idle_interval_usec_ = usec; + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; } inline void Server::set_payload_max_length(size_t length) { - payload_max_length_ = length; + payload_max_length_ = length; } inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { - if (bind_internal(host, port, socket_flags) < 0) return false; - return true; + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; } inline int Server::bind_to_any_port(const char *host, int socket_flags) { - return bind_internal(host, 0, socket_flags); + return bind_internal(host, 0, socket_flags); } -inline bool Server::listen_after_bind() { return listen_internal(); } +inline bool Server::listen_after_bind() { + return listen_internal(); +} inline bool Server::listen(const char *host, int port, int socket_flags) { - return bind_to_port(host, port, socket_flags) && listen_internal(); + return bind_to_port(host, port, socket_flags) && listen_internal(); } -inline bool Server::is_running() const { return is_running_; } +inline bool Server::is_running() const { + return is_running_; +} inline void Server::stop() { - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); - detail::shutdown_socket(sock); - detail::close_socket(sock); - } + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } } inline bool Server::parse_request_line(const char *s, Request &req) { - const static std::regex re( - "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " - "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); - - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[5]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3], false); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { detail::parse_query_text(m[4], req.params); } + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3], false); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { + detail::parse_query_text(m[4], req.params); + } - return true; - } + return true; + } - return false; + return false; } inline bool Server::write_response(Stream &strm, bool close_connection, const Request &req, Response &res) { - assert(res.status != -1); + assert(res.status != -1); - if (400 <= res.status && error_handler_) { error_handler_(req, res); } - - detail::BufferStream bstrm; + if (400 <= res.status && error_handler_) { + error_handler_(req, res); + } - // Response line - if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, - detail::status_message(res.status))) { - return false; - } + detail::BufferStream bstrm; - // Headers - if (close_connection || req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); - } else { - std::stringstream ss; - ss << "timeout=" << keep_alive_timeout_sec_ - << ", max=" << keep_alive_max_count_; - res.set_header("Keep-Alive", ss.str()); - } + // Response line + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } - if (!res.has_header("Content-Type") && - (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { - res.set_header("Content-Type", "text/plain"); - } + // Headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::stringstream ss; + ss << "timeout=" << keep_alive_timeout_sec_ + << ", max=" << keep_alive_max_count_; + res.set_header("Keep-Alive", ss.str()); + } - if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { - res.set_header("Accept-Ranges", "bytes"); - } + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + res.set_header("Content-Type", "text/plain"); + } - std::string content_type; - std::string boundary; + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } - if (req.ranges.size() > 1) { - boundary = detail::make_multipart_data_boundary(); + std::string content_type; + std::string boundary; - auto it = res.headers.find("Content-Type"); - if (it != res.headers.end()) { - content_type = it->second; - res.headers.erase(it); - } + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); - res.headers.emplace("Content-Type", - "multipart/byteranges; boundary=" + boundary); - } + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } - auto type = detail::encoding_type(req, res); + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } - if (res.body.empty()) { - if (res.content_length_ > 0) { - size_t length = 0; - if (req.ranges.empty()) { - length = res.content_length_; - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length_, 0); - auto offset = offsets.first; - length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.content_length_); - res.set_header("Content-Range", content_range); - } else { - length = detail::get_multipart_ranges_data_length(req, res, boundary, - content_type); - } - res.set_header("Content-Length", std::to_string(length)); - } else { - if (res.content_provider_) { - if (res.is_chunked_content_provider) { - res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); - } - } - } else { - res.set_header("Content-Length", "0"); - } - } - } else { - if (req.ranges.empty()) { - ; - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.body.size(), 0); - auto offset = offsets.first; - auto length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.body.size()); - res.set_header("Content-Range", content_range); - res.body = res.body.substr(offset, length); + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } else { + res.set_header("Content-Length", "0"); + } + } } else { - res.body = - detail::make_multipart_ranges_data(req, res, boundary, content_type); - } + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); + } - if (type != detail::EncodingType::None) { - std::shared_ptr compressor; + if (type != detail::EncodingType::None) { + std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); - res.set_header("Content-Encoding", "gzip"); + compressor = std::make_shared(); + res.set_header("Content-Encoding", "gzip"); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); - res.set_header("Content-Encoding", "brotli"); + compressor = std::make_shared(); + res.set_header("Content-Encoding", "brotli"); #endif - } + } + + if (compressor) { + std::string compressed; - if (compressor) { - std::string compressed; + if (!compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + return false; + } - if (!compressor->compress(res.body.data(), res.body.size(), true, - [&](const char *data, size_t data_len) { - compressed.append(data, data_len); - return true; - })) { - return false; + res.body.swap(compressed); + } } - res.body.swap(compressed); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length); - } - - if (!detail::write_headers(bstrm, res, Headers())) { return false; } + if (!detail::write_headers(bstrm, res, Headers())) { + return false; + } - // Flush buffer - auto &data = bstrm.get_buffer(); - strm.write(data.data(), data.size()); + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); - // Body - auto ret = true; - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!strm.write(res.body)) { ret = false; } - } else if (res.content_provider_) { - if (!write_content_with_provider(strm, req, res, boundary, - content_type)) { - ret = false; - } + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { + ret = false; + } + } else if (res.content_provider_) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + ret = false; + } + } } - } - // Log - if (logger_) { logger_(req, res); } + // Log + if (logger_) { + logger_(req, res); + } - return ret; + return ret; } inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - auto is_shutting_down = [this]() { - return this->svr_sock_ == INVALID_SOCKET; - }; - - if (res.content_length_ > 0) { - if (req.ranges.empty()) { - if (detail::write_content(strm, res.content_provider_, 0, - res.content_length_, is_shutting_down) < 0) { - return false; - } - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length_, 0); - auto offset = offsets.first; - auto length = offsets.second; - if (detail::write_content(strm, res.content_provider_, offset, length, - is_shutting_down) < 0) { - return false; - } + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider_, offset, length, + is_shutting_down) < 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, is_shutting_down)) { + return false; + } + } } else { - if (!detail::write_multipart_ranges_data( - strm, req, res, boundary, content_type, is_shutting_down)) { - return false; - } - } - } else { - if (res.is_chunked_content_provider) { - auto type = detail::encoding_type(req, res); + if (res.is_chunked_content_provider) { + auto type = detail::encoding_type(req, res); - std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + std::shared_ptr compressor; + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else { - compressor = std::make_shared(); - } - assert(compressor != nullptr); + } else { + compressor = std::make_shared(); + } + assert(compressor != nullptr); - if (detail::write_content_chunked(strm, res.content_provider_, - is_shutting_down, *compressor) < 0) { - return false; - } - } else { - if (detail::write_content_without_length(strm, res.content_provider_, - is_shutting_down) < 0) { - return false; - } + if (detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor) < 0) { + return false; + } + } else { + if (detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down) < 0) { + return false; + } + } } - } - return true; + return true; } inline bool Server::read_content(Stream &strm, Request &req, Response &res) { - MultipartFormDataMap::iterator cur; - if (read_content_core( - strm, req, res, - // Regular - [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } - req.body.append(buf, n); - return true; - }, - // Multipart - [&](const MultipartFormData &file) { - cur = req.files.emplace(file.name, file); - return true; - }, - [&](const char *buf, size_t n) { - auto &content = cur->second.content; - if (content.size() + n > content.max_size()) { return false; } - content.append(buf, n); - return true; - })) { - const auto &content_type = req.get_header_value("Content-Type"); - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); + MultipartFormDataMap::iterator cur; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { + return false; + } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + return true; } - return true; - } - return false; + return false; } inline bool Server::read_content_with_content_receiver( Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader multipart_header, ContentReceiver multipart_receiver) { - return read_content_core(strm, req, res, receiver, multipart_header, - multipart_receiver); + return read_content_core(strm, req, res, receiver, multipart_header, + multipart_receiver); } inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader mulitpart_header, ContentReceiver multipart_receiver) { - detail::MultipartFormDataParser multipart_form_data_parser; - ContentReceiver out; + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return false; + } - if (req.is_multipart_form_data()) { - const auto &content_type = req.get_header_value("Content-Type"); - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary)) { - res.status = 400; - return false; - } - - multipart_form_data_parser.set_boundary(std::move(boundary)); - out = [&](const char *buf, size_t n) { - /* For debug - size_t pos = 0; - while (pos < n) { - auto read_size = std::min(1, n - pos); - auto ret = multipart_form_data_parser.parse( - buf + pos, read_size, multipart_receiver, mulitpart_header); - if (!ret) { return false; } - pos += read_size; - } - return true; - */ - return multipart_form_data_parser.parse(buf, n, multipart_receiver, - mulitpart_header); - }; - } else { - out = receiver; - } + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = std::min(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, mulitpart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { - return true; - } + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } - if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, - out, true)) { - return false; - } + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } - if (req.is_multipart_form_data()) { - if (!multipart_form_data_parser.is_valid()) { - res.status = 400; - return false; + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return false; + } } - } - return true; + return true; } inline bool Server::handle_file_request(Request &req, Response &res, bool head) { - for (const auto &kv : base_dirs_) { - const auto &mount_point = kv.first; - const auto &base_dir = kv.second; - - // Prefix match - if (!req.path.compare(0, mount_point.size(), mount_point)) { - std::string sub_path = "/" + req.path.substr(mount_point.size()); - if (detail::is_valid_path(sub_path)) { - auto path = base_dir + sub_path; - if (path.back() == '/') { path += "index.html"; } - - if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = - detail::find_content_type(path, file_extension_and_mimetype_map_); - if (type) { res.set_header("Content-Type", type); } - res.status = 200; - if (!head && file_request_handler_) { - file_request_handler_(req, res); - } - return true; - } - } - } - } - return false; + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.compare(0, mount_point.size(), mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { + path += "index.html"; + } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { + res.set_header("Content-Type", type); + } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; } inline socket_t Server::create_server_socket(const char *host, int port, int socket_flags, SocketOptions socket_options) const { - return detail::create_socket( - host, port, socket_flags, tcp_nodelay_, socket_options, - [](socket_t sock, struct addrinfo &ai) -> bool { - if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { - return false; - } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }); + return detail::create_socket( + host, port, socket_flags, tcp_nodelay_, socket_options, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }); } inline int Server::bind_internal(const char *host, int port, int socket_flags) { - if (!is_valid()) { return -1; } - - svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); - if (svr_sock_ == INVALID_SOCKET) { return -1; } + if (!is_valid()) { + return -1; + } - if (port == 0) { - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - if (getsockname(svr_sock_, reinterpret_cast(&addr), - &addr_len) == -1) { - return -1; + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { + return -1; } - if (addr.ss_family == AF_INET) { - return ntohs(reinterpret_cast(&addr)->sin_port); - } else if (addr.ss_family == AF_INET6) { - return ntohs(reinterpret_cast(&addr)->sin6_port); + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } } else { - return -1; + return port; } - } else { - return port; - } } inline bool Server::listen_internal() { - auto ret = true; - is_running_ = true; + auto ret = true; + is_running_ = true; - { - std::unique_ptr task_queue(new_task_queue()); + { + std::unique_ptr task_queue(new_task_queue()); - while (svr_sock_ != INVALID_SOCKET) { + while (svr_sock_ != INVALID_SOCKET) { #ifdef __linux__ - if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { #endif - auto val = detail::select_read(svr_sock_, idle_interval_sec_, - idle_interval_usec_); - if (val == 0) { // Timeout - task_queue->on_idle(); - continue; - } + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } #ifdef __linux__ - } + } #endif - socket_t sock = accept(svr_sock_, nullptr, nullptr); - - if (sock == INVALID_SOCKET) { - if (errno == EMFILE) { - // The per-process limit of open file descriptors has been reached. - // Try to accept new connections after a short sleep. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. - } - break; - } + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } #if __cplusplus > 201703L - task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); + task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); #else - task_queue->enqueue([=]() { process_and_close_socket(sock); }); + task_queue->enqueue([=]() { process_and_close_socket(sock); }); #endif - } + } - task_queue->shutdown(); - } + task_queue->shutdown(); + } - is_running_ = false; - return ret; + is_running_ = false; + return ret; } inline bool Server::routing(Request &req, Response &res, Stream &strm) { - // File handler - bool is_head_request = req.method == "HEAD"; - if ((req.method == "GET" || is_head_request) && - handle_file_request(req, res, is_head_request)) { - return true; - } + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } - if (detail::expect_content(req)) { - // Content reader handler - { - ContentReader reader( - [&](ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, receiver, - nullptr, nullptr); - }, - [&](MultipartContentHeader header, ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, nullptr, - header, receiver); - }); - - if (req.method == "POST") { - if (dispatch_request_for_content_reader( - req, res, reader, post_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "PUT") { - if (dispatch_request_for_content_reader( - req, res, reader, put_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "PATCH") { - if (dispatch_request_for_content_reader( - req, res, reader, patch_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "DELETE") { - if (dispatch_request_for_content_reader( - req, res, reader, delete_handlers_for_content_reader_)) { - return true; - } - } - } - - // Read content into `req.body` - if (!read_content(strm, req, res)) { return false; } - } - - // Regular handler - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } else if (req.method == "PATCH") { - return dispatch_request(req, res, patch_handlers_); - } - - res.status = 400; - return false; + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, receiver, + nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, reader, delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { + return false; + } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; } inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) { - try { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; + try { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; - } + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + } catch (const std::exception &ex) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", ex.what()); + } catch (...) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); } - } catch (const std::exception &ex) { - res.status = 500; - res.set_header("EXCEPTION_WHAT", ex.what()); - } catch (...) { - res.status = 500; - res.set_header("EXCEPTION_WHAT", "UNKNOWN"); - } - return false; + return false; } inline bool Server::dispatch_request_for_content_reader( Request &req, Response &res, ContentReader content_reader, const HandlersForContentReader &handlers) { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res, content_reader); - return true; + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } } - } - return false; + return false; } inline bool Server::process_request(Stream &strm, bool close_connection, bool &connection_closed, const std::function &setup_request) { - std::array buf{}; + std::array buf{}; - detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - // Connection has been closed on client - if (!line_reader.getline()) { return false; } + // Connection has been closed on client + if (!line_reader.getline()) { + return false; + } - Request req; - Response res; + Request req; + Response res; - res.version = "HTTP/1.1"; + res.version = "HTTP/1.1"; - // Check if the request URI doesn't exceed the limit - if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = 414; - return write_response(strm, close_connection, req, res); - } + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, close_connection, req, res); + } - // Request line and headers - if (!parse_request_line(line_reader.ptr(), req) || - !detail::read_headers(strm, req.headers)) { - res.status = 400; - return write_response(strm, close_connection, req, res); - } + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, close_connection, req, res); + } - if (req.get_header_value("Connection") == "close") { - connection_closed = true; - } + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } - if (req.version == "HTTP/1.0" && - req.get_header_value("Connection") != "Keep-Alive") { - connection_closed = true; - } + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } - strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); - req.set_header("REMOTE_ADDR", req.remote_addr); - req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); - if (req.has_header("Range")) { - const auto &range_header_value = req.get_header_value("Range"); - if (!detail::parse_range_header(range_header_value, req.ranges)) { - // TODO: error + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } } - } - - if (setup_request) { setup_request(req); } - if (req.get_header_value("Expect") == "100-continue") { - auto status = 100; - if (expect_100_continue_handler_) { - status = expect_100_continue_handler_(req, res); + if (setup_request) { + setup_request(req); } - switch (status) { - case 100: - case 417: - strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, - detail::status_message(status)); - break; - default: return write_response(strm, close_connection, req, res); + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + detail::status_message(status)); + break; + default: + return write_response(strm, close_connection, req, res); + } } - } - // Rounting - if (routing(req, res, strm)) { - if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } - } else { - if (res.status == -1) { res.status = 404; } - } + // Rounting + if (routing(req, res, strm)) { + if (res.status == -1) { + res.status = req.ranges.empty() ? 200 : 206; + } + } else { + if (res.status == -1) { + res.status = 404; + } + } - return write_response(strm, close_connection, req, res); + return write_response(strm, close_connection, req, res); } -inline bool Server::is_valid() const { return true; } +inline bool Server::is_valid() const { + return true; +} inline bool Server::process_and_close_socket(socket_t sock) { - auto ret = detail::process_server_socket( - sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, - nullptr); - }); + auto ret = detail::process_server_socket( + sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, + [this](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + nullptr); + }); - detail::shutdown_socket(sock); - detail::close_socket(sock); - return ret; + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; } // HTTP client implementation inline ClientImpl::ClientImpl(const std::string &host) - : ClientImpl(host, 80, std::string(), std::string()) {} + : ClientImpl(host, 80, std::string(), std::string()) { +} inline ClientImpl::ClientImpl(const std::string &host, int port) - : ClientImpl(host, port, std::string(), std::string()) {} + : ClientImpl(host, port, std::string(), std::string()) { +} inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(host), port_(port), host_and_port_(host_ + ":" + std::to_string(port_)), - client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + client_cert_path_(client_cert_path), client_key_path_(client_key_path) { +} -inline ClientImpl::~ClientImpl() { stop_core(); } +inline ClientImpl::~ClientImpl() { + stop_core(); +} -inline bool ClientImpl::is_valid() const { return true; } +inline bool ClientImpl::is_valid() const { + return true; +} -inline Error ClientImpl::get_last_error() const { return error_; } +inline Error ClientImpl::get_last_error() const { + return error_; +} inline socket_t ClientImpl::create_client_socket() const { - if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_.c_str(), proxy_port_, tcp_nodelay_, socket_options_, + connection_timeout_sec_, connection_timeout_usec_, interface_, error_); + } return detail::create_client_socket( - proxy_host_.c_str(), proxy_port_, tcp_nodelay_, socket_options_, + host_.c_str(), port_, tcp_nodelay_, socket_options_, connection_timeout_sec_, connection_timeout_usec_, interface_, error_); - } - return detail::create_client_socket( - host_.c_str(), port_, tcp_nodelay_, socket_options_, - connection_timeout_sec_, connection_timeout_usec_, interface_, error_); } inline bool ClientImpl::create_and_connect_socket(Socket &socket) { - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } - socket.sock = sock; - return true; + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { + return false; + } + socket.sock = sock; + return true; } inline void ClientImpl::close_socket(Socket &socket, bool /*process_socket_ret*/) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - socket_.ssl = nullptr; + socket_.ssl = nullptr; #endif } inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { - std::array buf; + std::array buf; - detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - if (!line_reader.getline()) { return false; } - - const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); - - std::cmatch m; - if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); - res.reason = std::string(m[3]); + if (!line_reader.getline()) { + return false; + } - // Ignore '100 Continue' - while (res.status == 100) { - if (!line_reader.getline()) { return false; } // CRLF - if (!line_reader.getline()) { return false; } // next response line + const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); - if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } res.version = std::string(m[1]); res.status = std::stoi(std::string(m[2])); res.reason = std::string(m[3]); - } - return true; + // Ignore '100 Continue' + while (res.status == 100) { + if (!line_reader.getline()) { + return false; + } // CRLF + if (!line_reader.getline()) { + return false; + } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; } inline bool ClientImpl::send(const Request &req, Response &res) { - std::lock_guard request_mutex_guard(request_mutex_); + std::lock_guard request_mutex_guard(request_mutex_); - { - std::lock_guard guard(socket_mutex_); + { + std::lock_guard guard(socket_mutex_); - auto is_alive = false; - if (socket_.is_open()) { - is_alive = detail::select_write(socket_.sock, 0, 0) > 0; - if (!is_alive) { close_socket(socket_, false); } - } + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::select_write(socket_.sock, 0, 0) > 0; + if (!is_alive) { + close_socket(socket_, false); + } + } - if (!is_alive) { - if (!create_and_connect_socket(socket_)) { return false; } + if (!is_alive) { + if (!create_and_connect_socket(socket_)) { + return false; + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // TODO: refactoring - if (is_ssl()) { - auto &scli = static_cast(*this); - if (!proxy_host_.empty() && proxy_port_ != -1) { - bool success = false; - if (!scli.connect_with_proxy(socket_, res, success)) { - return success; - } - } - - if (!scli.initialize_ssl(socket_)) { return false; } - } + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + bool success = false; + if (!scli.connect_with_proxy(socket_, res, success)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_)) { + return false; + } + } #endif + } } - } - auto close_connection = !keep_alive_; + auto close_connection = !keep_alive_; - auto ret = process_socket(socket_, [&](Stream &strm) { - return handle_request(strm, req, res, close_connection); - }); + auto ret = process_socket(socket_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection); + }); - if (close_connection || !ret) { stop_core(); } + if (close_connection || !ret) { + stop_core(); + } - if (!ret) { - if (error_ == Error::Success) { error_ = Error::Unknown; } - } + if (!ret) { + if (error_ == Error::Success) { + error_ = Error::Unknown; + } + } - return ret; + return ret; } inline bool ClientImpl::handle_request(Stream &strm, const Request &req, Response &res, bool close_connection) { - if (req.path.empty()) { - error_ = Error::Connection; - return false; - } + if (req.path.empty()) { + error_ = Error::Connection; + return false; + } - bool ret; + bool ret; - if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { - auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; - ret = process_request(strm, req2, res, close_connection); - } else { - ret = process_request(strm, req, res, close_connection); - } + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection); + } else { + ret = process_request(strm, req, res, close_connection); + } - if (!ret) { return false; } + if (!ret) { + return false; + } - if (300 < res.status && res.status < 400 && follow_location_) { - ret = redirect(req, res); - } + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if ((res.status == 401 || res.status == 407) && - req.authorization_count_ < 5) { - auto is_proxy = res.status == 407; - const auto &username = - is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; - const auto &password = - is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; - - if (!username.empty() && !password.empty()) { - std::map auth; - if (detail::parse_www_authenticate(res, auth, is_proxy)) { - Request new_req = req; - new_req.authorization_count_ += 1; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - new_req.headers.erase(key); - new_req.headers.insert(detail::make_digest_authentication_header( - req, auth, new_req.authorization_count_, detail::random_string(10), - username, password, is_proxy)); - - Response new_res; - - ret = send(new_req, new_res); - if (ret) { res = new_res; } - } - } - } + if ((res.status == 401 || res.status == 407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { + res = new_res; + } + } + } + } #endif - return ret; + return ret; } inline bool ClientImpl::redirect(const Request &req, Response &res) { - if (req.redirect_count == 0) { - error_ = Error::ExceedRedirectCount; - return false; - } + if (req.redirect_count == 0) { + error_ = Error::ExceedRedirectCount; + return false; + } - auto location = detail::decode_url(res.get_header_value("location"), true); - if (location.empty()) { return false; } + auto location = detail::decode_url(res.get_header_value("location"), true); + if (location.empty()) { + return false; + } - const static std::regex re( - R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + const static std::regex re( + R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); - std::smatch m; - if (!std::regex_match(location, m, re)) { return false; } + std::smatch m; + if (!std::regex_match(location, m, re)) { + return false; + } - auto scheme = is_ssl() ? "https" : "http"; + auto scheme = is_ssl() ? "https" : "http"; - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - auto port_str = m[3].str(); - auto next_path = m[4].str(); + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); - auto next_port = port_; - if (!port_str.empty()) { - next_port = std::stoi(port_str); - } else if (!next_scheme.empty()) { - next_port = next_scheme == "https" ? 443 : 80; - } + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } - if (next_scheme.empty()) { next_scheme = scheme; } - if (next_host.empty()) { next_host = host_; } - if (next_path.empty()) { next_path = "/"; } + if (next_scheme.empty()) { + next_scheme = scheme; + } + if (next_host.empty()) { + next_host = host_; + } + if (next_path.empty()) { + next_path = "/"; + } - if (next_scheme == scheme && next_host == host_ && next_port == port_) { - return detail::redirect(*this, req, res, next_path); - } else { - if (next_scheme == "https") { + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSLClient cli(next_host.c_str(), next_port); - cli.copy_settings(*this); - auto ret = detail::redirect(cli, req, res, next_path); - if (!ret) { error_ = cli.get_last_error(); } - return ret; + SSLClient cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + auto ret = detail::redirect(cli, req, res, next_path); + if (!ret) { + error_ = cli.get_last_error(); + } + return ret; #else - return false; + return false; #endif - } else { - ClientImpl cli(next_host.c_str(), next_port); - cli.copy_settings(*this); - auto ret = detail::redirect(cli, req, res, next_path); - if (!ret) { error_ = cli.get_last_error(); } - return ret; + } else { + ClientImpl cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + auto ret = detail::redirect(cli, req, res, next_path); + if (!ret) { + error_ = cli.get_last_error(); + } + return ret; + } } - } } inline bool ClientImpl::write_request(Stream &strm, const Request &req, bool close_connection) { - detail::BufferStream bstrm; - - // Request line - const auto &path = detail::encode_url(req.path); + detail::BufferStream bstrm; - bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + // Request line + const auto &path = detail::encode_url(req.path); - // Additonal headers - Headers headers; - if (close_connection) { headers.emplace("Connection", "close"); } + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - if (!req.has_header("Host")) { - if (is_ssl()) { - if (port_ == 443) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } - } else { - if (port_ == 80) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } + // Additonal headers + Headers headers; + if (close_connection) { + headers.emplace("Connection", "close"); } - } - - if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } - if (!req.has_header("User-Agent")) { - headers.emplace("User-Agent", "cpp-httplib/0.7"); - } - - if (req.body.empty()) { - if (req.content_provider) { - auto length = std::to_string(req.content_length); - headers.emplace("Content-Length", length); - } else { - headers.emplace("Content-Length", "0"); - } - } else { - if (!req.has_header("Content-Type")) { - headers.emplace("Content-Type", "text/plain"); + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } } - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - headers.emplace("Content-Length", length); + if (!req.has_header("Accept")) { + headers.emplace("Accept", "*/*"); } - } - if (!basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( - basic_auth_username_, basic_auth_password_, false)); - } + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.7"); + } - if (!proxy_basic_auth_username_.empty() && - !proxy_basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( - proxy_basic_auth_username_, proxy_basic_auth_password_, true)); - } + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } - if (!bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( - bearer_token_auth_token_, false)); - } + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } - if (!proxy_bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( - proxy_bearer_token_auth_token_, true)); - } + if (!basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } - detail::write_headers(bstrm, req, headers); + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } - // Flush buffer - auto &data = bstrm.get_buffer(); - if (!detail::write_data(strm, data.data(), data.size())) { - error_ = Error::Write; - return false; - } + if (!bearer_token_auth_token_.empty()) { + headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } - // Body - if (req.body.empty()) { - if (req.content_provider) { - size_t offset = 0; - size_t end_offset = req.content_length; + if (!proxy_bearer_token_auth_token_.empty()) { + headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } - bool ok = true; + detail::write_headers(bstrm, req, headers); - DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - if (detail::write_data(strm, d, l)) { - offset += l; - } else { - ok = false; - } - } - }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error_ = Error::Write; + return false; + } - while (offset < end_offset) { - if (!req.content_provider(offset, end_offset - offset, data_sink)) { - error_ = Error::Canceled; - return false; - } - if (!ok) { - error_ = Error::Write; - return false; + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + bool ok = true; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + if (detail::write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (offset < end_offset) { + if (!req.content_provider(offset, end_offset - offset, data_sink)) { + error_ = Error::Canceled; + return false; + } + if (!ok) { + error_ = Error::Write; + return false; + } + } } - } + } else { + return detail::write_data(strm, req.body.data(), req.body.size()); } - } else { - return detail::write_data(strm, req.body.data(), req.body.size()); - } - return true; + return true; } inline std::shared_ptr ClientImpl::send_with_content_provider( @@ -4956,541 +5365,572 @@ inline std::shared_ptr ClientImpl::send_with_content_provider( const std::string &body, size_t content_length, ContentProvider content_provider, const char *content_type) { - Request req; - req.method = method; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = method; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - if (content_type) { req.headers.emplace("Content-Type", content_type); } + if (content_type) { + req.headers.emplace("Content-Type", content_type); + } #ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (compress_) { - detail::gzip_compressor compressor; - - if (content_provider) { - auto ok = true; - size_t offset = 0; - - DataSink data_sink; - data_sink.write = [&](const char *data, size_t data_len) { - if (ok) { - auto last = offset + data_len == content_length; - - auto ret = compressor.compress( - data, data_len, last, [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - }); - - if (ret) { - offset += data_len; - } else { - ok = false; - } - } - }; - data_sink.is_writable = [&](void) { return ok && true; }; - - while (ok && offset < content_length) { - if (!content_provider(offset, content_length - offset, data_sink)) { - error_ = Error::Canceled; - return nullptr; + if (compress_) { + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && true; }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error_ = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body.data(), body.size(), true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + return nullptr; + } } - } - } else { - if (!compressor.compress(body.data(), body.size(), true, - [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - })) { - return nullptr; - } - } - req.headers.emplace("Content-Encoding", "gzip"); - } else + req.headers.emplace("Content-Encoding", "gzip"); + } else #endif - { - if (content_provider) { - req.content_length = content_length; - req.content_provider = content_provider; - } else { - req.body = body; + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } } - } - auto res = std::make_shared(); + auto res = std::make_shared(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? res : nullptr; } inline bool ClientImpl::process_request(Stream &strm, const Request &req, Response &res, bool close_connection) { - // Send request - if (!write_request(strm, req, close_connection)) { return false; } + // Send request + if (!write_request(strm, req, close_connection)) { + return false; + } - // Receive response and headers - if (!read_response_line(strm, res) || - !detail::read_headers(strm, res.headers)) { - error_ = Error::Read; - return false; - } + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + error_ = Error::Read; + return false; + } - if (req.response_handler) { - if (!req.response_handler(res)) { - error_ = Error::Canceled; - return false; + if (req.response_handler) { + if (!req.response_handler(res)) { + error_ = Error::Canceled; + return false; + } } - } - // Body - if (req.method != "HEAD" && req.method != "CONNECT") { - auto out = - req.content_receiver - ? static_cast([&](const char *buf, size_t n) { + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + auto out = + req.content_receiver ? static_cast([&](const char *buf, size_t n) { auto ret = req.content_receiver(buf, n); - if (!ret) { error_ = Error::Canceled; } + if (!ret) { + error_ = Error::Canceled; + } return ret; - }) - : static_cast([&](const char *buf, size_t n) { - if (res.body.size() + n > res.body.max_size()) { return false; } - res.body.append(buf, n); + }) : + static_cast([&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { + return false; + } + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress) { return true; - }); - - auto progress = [&](uint64_t current, uint64_t total) { - if (!req.progress) { return true; } - auto ret = req.progress(current, total); - if (!ret) { error_ = Error::Canceled; } - return ret; - }; + } + auto ret = req.progress(current, total); + if (!ret) { + error_ = Error::Canceled; + } + return ret; + }; - int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, progress, out, decompress_)) { - if (error_ != Error::Canceled) { error_ = Error::Read; } - return false; + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, progress, out, decompress_)) { + if (error_ != Error::Canceled) { + error_ = Error::Read; + } + return false; + } } - } - if (res.get_header_value("Connection") == "close" || - (res.version == "HTTP/1.0" && res.reason != "Connection established")) { - stop_core(); - } + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + stop_core(); + } - // Log - if (logger_) { logger_(req, res); } + // Log + if (logger_) { + logger_(req, res); + } - return true; + return true; } inline bool ClientImpl::process_socket(Socket &socket, std::function callback) { - return detail::process_client_socket(socket.sock, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, callback); + return detail::process_client_socket(socket.sock, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, callback); } -inline bool ClientImpl::is_ssl() const { return false; } +inline bool ClientImpl::is_ssl() const { + return false; +} inline Result ClientImpl::Get(const char *path) { - return Get(path, Headers(), Progress()); + return Get(path, Headers(), Progress()); } inline Result ClientImpl::Get(const char *path, Progress progress) { - return Get(path, Headers(), std::move(progress)); + return Get(path, Headers(), std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers) { - return Get(path, headers, Progress()); + return Get(path, headers, Progress()); } inline Result ClientImpl::Get(const char *path, const Headers &headers, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.progress = std::move(progress); + Request req; + req.method = "GET"; + req.path = path; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.progress = std::move(progress); - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Get(const char *path, ContentReceiver content_receiver) { - return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), nullptr, std::move(content_receiver), - std::move(progress)); + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) { - return Get(path, headers, nullptr, std::move(content_receiver), nullptr); + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, Progress progress) { - return Get(path, headers, nullptr, std::move(content_receiver), - std::move(progress)); + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - nullptr); + return Get(path, Headers(), std::move(response_handler), content_receiver, + nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, headers, std::move(response_handler), content_receiver, - nullptr); + return Get(path, headers, std::move(response_handler), content_receiver, + nullptr); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - progress); + return Get(path, Headers(), std::move(response_handler), content_receiver, + progress); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.response_handler = std::move(response_handler); - req.content_receiver = std::move(content_receiver); - req.progress = std::move(progress); + Request req; + req.method = "GET"; + req.path = path; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Head(const char *path) { - return Head(path, Headers()); + return Head(path, Headers()); } inline Result ClientImpl::Head(const char *path, const Headers &headers) { - Request req; - req.method = "HEAD"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = "HEAD"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Post(const char *path) { - return Post(path, std::string(), nullptr); + return Post(path, std::string(), nullptr); } inline Result ClientImpl::Post(const char *path, const std::string &body, const char *content_type) { - return Post(path, Headers(), body, content_type); + return Post(path, Headers(), body, content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Params ¶ms) { - return Post(path, Headers(), params); + return Post(path, Headers(), params); } inline Result ClientImpl::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Post(path, Headers(), content_length, content_provider, content_type); + return Post(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Headers &headers, const Params ¶ms) { - auto query = detail::params_to_query_str(params); - return Post(path, headers, query, "application/x-www-form-urlencoded"); + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); } inline Result ClientImpl::Post(const char *path, const MultipartFormDataItems &items) { - return Post(path, Headers(), items); + return Post(path, Headers(), items); } inline Result ClientImpl::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - auto boundary = detail::make_multipart_data_boundary(); + auto boundary = detail::make_multipart_data_boundary(); - std::string body; + std::string body; - for (const auto &item : items) { - body += "--" + boundary + "\r\n"; - body += "Content-Disposition: form-data; name=\"" + item.name + "\""; - if (!item.filename.empty()) { - body += "; filename=\"" + item.filename + "\""; - } - body += "\r\n"; - if (!item.content_type.empty()) { - body += "Content-Type: " + item.content_type + "\r\n"; + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; } - body += "\r\n"; - body += item.content + "\r\n"; - } - body += "--" + boundary + "--\r\n"; + body += "--" + boundary + "--\r\n"; - std::string content_type = "multipart/form-data; boundary=" + boundary; - return Post(path, headers, body, content_type.c_str()); + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); } inline Result ClientImpl::Put(const char *path) { - return Put(path, std::string(), nullptr); + return Put(path, std::string(), nullptr); } inline Result ClientImpl::Put(const char *path, const std::string &body, const char *content_type) { - return Put(path, Headers(), body, content_type); + return Put(path, Headers(), body, content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Put(path, Headers(), content_length, content_provider, content_type); + return Put(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Put(const char *path, const Params ¶ms) { - return Put(path, Headers(), params); + return Put(path, Headers(), params); } inline Result ClientImpl::Put(const char *path, const Headers &headers, const Params ¶ms) { - auto query = detail::params_to_query_str(params); - return Put(path, headers, query, "application/x-www-form-urlencoded"); + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); } inline Result ClientImpl::Patch(const char *path, const std::string &body, const char *content_type) { - return Patch(path, Headers(), body, content_type); + return Patch(path, Headers(), body, content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, body, 0, - nullptr, content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PATCH", path, headers, body, 0, + nullptr, content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Patch(path, Headers(), content_length, content_provider, content_type); + return Patch(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Delete(const char *path) { - return Delete(path, Headers(), std::string(), nullptr); + return Delete(path, Headers(), std::string(), nullptr); } inline Result ClientImpl::Delete(const char *path, const std::string &body, const char *content_type) { - return Delete(path, Headers(), body, content_type); + return Delete(path, Headers(), body, content_type); } inline Result ClientImpl::Delete(const char *path, const Headers &headers) { - return Delete(path, headers, std::string(), nullptr); + return Delete(path, headers, std::string(), nullptr); } inline Result ClientImpl::Delete(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - Request req; - req.method = "DELETE"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; - - if (content_type) { req.headers.emplace("Content-Type", content_type); } - req.body = body; + Request req; + req.method = "DELETE"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; + + if (content_type) { + req.headers.emplace("Content-Type", content_type); + } + req.body = body; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Options(const char *path) { - return Options(path, Headers()); + return Options(path, Headers()); } inline Result ClientImpl::Options(const char *path, const Headers &headers) { - Request req; - req.method = "OPTIONS"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = "OPTIONS"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline size_t ClientImpl::is_socket_open() const { - std::lock_guard guard(socket_mutex_); - return socket_.is_open(); + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); } inline void ClientImpl::stop() { - stop_core(); - error_ = Error::Canceled; + stop_core(); + error_ = Error::Canceled; } inline void ClientImpl::stop_core() { - std::lock_guard guard(socket_mutex_); - if (socket_.is_open()) { - detail::shutdown_socket(socket_.sock); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - close_socket(socket_, true); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } + std::lock_guard guard(socket_mutex_); + if (socket_.is_open()) { + detail::shutdown_socket(socket_.sock); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + close_socket(socket_, true); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } } inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { - connection_timeout_sec_ = sec; - connection_timeout_usec_ = usec; + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; } inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { - read_timeout_sec_ = sec; - read_timeout_usec_ = usec; + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { - write_timeout_sec_ = sec; - write_timeout_usec_ = usec; + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; } inline void ClientImpl::set_basic_auth(const char *username, const char *password) { - basic_auth_username_ = username; - basic_auth_password_ = password; + basic_auth_username_ = username; + basic_auth_password_ = password; } inline void ClientImpl::set_bearer_token_auth(const char *token) { - bearer_token_auth_token_ = token; + bearer_token_auth_token_ = token; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void ClientImpl::set_digest_auth(const char *username, const char *password) { - digest_auth_username_ = username; - digest_auth_password_ = password; + digest_auth_username_ = username; + digest_auth_password_ = password; } #endif -inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } +inline void ClientImpl::set_keep_alive(bool on) { + keep_alive_ = on; +} -inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } +inline void ClientImpl::set_follow_location(bool on) { + follow_location_ = on; +} inline void ClientImpl::set_default_headers(Headers headers) { - default_headers_ = std::move(headers); + default_headers_ = std::move(headers); } -inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void ClientImpl::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} inline void ClientImpl::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = socket_options; } -inline void ClientImpl::set_compress(bool on) { compress_ = on; } +inline void ClientImpl::set_compress(bool on) { + compress_ = on; +} -inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } +inline void ClientImpl::set_decompress(bool on) { + decompress_ = on; +} -inline void ClientImpl::set_interface(const char *intf) { interface_ = intf; } +inline void ClientImpl::set_interface(const char *intf) { + interface_ = intf; +} inline void ClientImpl::set_proxy(const char *host, int port) { - proxy_host_ = host; - proxy_port_ = port; + proxy_host_ = host; + proxy_port_ = port; } inline void ClientImpl::set_proxy_basic_auth(const char *username, const char *password) { - proxy_basic_auth_username_ = username; - proxy_basic_auth_password_ = password; + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; } inline void ClientImpl::set_proxy_bearer_token_auth(const char *token) { - proxy_bearer_token_auth_token_ = token; + proxy_bearer_token_auth_token_ = token; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void ClientImpl::set_proxy_digest_auth(const char *username, const char *password) { - proxy_digest_auth_username_ = username; - proxy_digest_auth_password_ = password; + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; } #endif inline void ClientImpl::set_logger(Logger logger) { - logger_ = std::move(logger); + logger_ = std::move(logger); } /* @@ -5499,66 +5939,66 @@ inline void ClientImpl::set_logger(Logger logger) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT namespace detail { -template +template inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup) { - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } - if (ssl) { - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); + if (ssl) { + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); - if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - return nullptr; + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + return nullptr; + } } - } - return ssl; + return ssl; } inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, bool process_socket_ret) { - if (process_socket_ret) { - SSL_shutdown(ssl); // shutdown only if not already closed by remote - } + if (process_socket_ret) { + SSL_shutdown(ssl); // shutdown only if not already closed by remote + } - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); } -template +template inline bool process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); + return process_server_socket_core( + sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); } -template +template inline bool process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm); + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); } #if OPENSSL_VERSION_NUMBER < 0x10100000L @@ -5566,49 +6006,51 @@ static std::shared_ptr> openSSL_locks_; class SSLThreadLocks { public: - SSLThreadLocks() { - openSSL_locks_ = - std::make_shared>(CRYPTO_num_locks()); - CRYPTO_set_locking_callback(locking_callback); - } + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } - ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + ~SSLThreadLocks() { + CRYPTO_set_locking_callback(nullptr); + } private: - static void locking_callback(int mode, int type, const char * /*file*/, - int /*line*/) { - auto &lk = (*openSSL_locks_)[static_cast(type)]; - if (mode & CRYPTO_LOCK) { - lk.lock(); - } else { - lk.unlock(); + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &lk = (*openSSL_locks_)[static_cast(type)]; + if (mode & CRYPTO_LOCK) { + lk.lock(); + } else { + lk.unlock(); + } } - } }; #endif class SSLInit { public: - SSLInit() { + SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - SSL_load_error_strings(); - SSL_library_init(); + SSL_load_error_strings(); + SSL_library_init(); #else - OPENSSL_init_ssl( - OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); #endif - } + } - ~SSLInit() { + ~SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - ERR_free_strings(); + ERR_free_strings(); #endif - } + } private: #if OPENSSL_VERSION_NUMBER < 0x10100000L - SSLThreadLocks thread_init_; + SSLThreadLocks thread_init_; #endif }; @@ -5622,839 +6064,904 @@ inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec) { - { - timeval tv; - tv.tv_sec = static_cast(read_timeout_sec); - tv.tv_usec = static_cast(read_timeout_usec); + { + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), - sizeof(tv)); - } - { - timeval tv; - tv.tv_sec = static_cast(write_timeout_sec); - tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + } + { + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); - setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), - sizeof(tv)); - } + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + } } -inline SSLSocketStream::~SSLSocketStream() {} +inline SSLSocketStream::~SSLSocketStream() { +} inline bool SSLSocketStream::is_readable() const { - return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SSLSocketStream::is_writable() const { - return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > - 0; + return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > + 0; } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0 || is_readable()) { - return SSL_read(ssl_, ptr, static_cast(size)); - } - return -1; + if (SSL_pending(ssl_) > 0 || is_readable()) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } - return -1; + if (is_writable()) { + return SSL_write(ssl_, ptr, static_cast(size)); + } + return -1; } inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { - detail::get_remote_ip_and_port(sock_, ip, port); + detail::get_remote_ip_and_port(sock_, ip, port); } static SSLInit sslinit_; -} // namespace detail +} // namespace detail // SSL HTTP server implementation inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, const char *client_ca_cert_dir_path) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); - - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - // if (client_ca_cert_file_path) { - // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); - // SSL_CTX_set_client_CA_list(ctx_, list); - // } - - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); - - SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } - } } inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - if (SSL_CTX_use_certificate(ctx_, cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_store) { - - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - - SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } - } } inline SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { + SSL_CTX_free(ctx_); + } } -inline bool SSLServer::is_valid() const { return ctx_; } +inline bool SSLServer::is_valid() const { + return ctx_; +} inline bool SSLServer::process_and_close_socket(socket_t sock) { - auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept, - [](SSL * /*ssl*/) { return true; }); - - if (ssl) { - auto ret = detail::process_server_socket_ssl( - ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, - [this, ssl](Stream &strm, bool close_connection, - bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, - [&](Request &req) { req.ssl = ssl; }); - }); - - detail::ssl_delete(ctx_mutex_, ssl, ret); - return ret; - } + auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept, + [](SSL * /*ssl*/) { return true; }); + + if (ssl) { + auto ret = detail::process_server_socket_ssl( + ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [this, ssl](Stream &strm, bool close_connection, + bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + detail::ssl_delete(ctx_mutex_, ssl, ret); + return ret; + } - detail::close_socket(sock); - return false; + detail::close_socket(sock); + return false; } // SSL HTTP client implementation inline SSLClient::SSLClient(const std::string &host) - : SSLClient(host, 443, std::string(), std::string()) {} + : SSLClient(host, 443, std::string(), std::string()) { +} inline SSLClient::SSLClient(const std::string &host, int port) - : SSLClient(host, port, std::string(), std::string()) {} + : SSLClient(host, port, std::string(), std::string()) { +} inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (!client_cert_path.empty() && !client_key_path.empty()) { - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } - } } inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key) : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (client_cert != nullptr && client_key != nullptr) { - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert != nullptr && client_key != nullptr) { + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } - } } inline SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { + SSL_CTX_free(ctx_); + } } -inline bool SSLClient::is_valid() const { return ctx_; } +inline bool SSLClient::is_valid() const { + return ctx_; +} inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } - if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } + if (ca_cert_file_path) { + ca_cert_file_path_ = ca_cert_file_path; + } + if (ca_cert_dir_path) { + ca_cert_dir_path_ = ca_cert_dir_path; + } } inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { ca_cert_store_ = ca_cert_store; } + if (ca_cert_store) { + ca_cert_store_ = ca_cert_store; + } } inline void SSLClient::enable_server_certificate_verification(bool enabled) { - server_certificate_verification_ = enabled; + server_certificate_verification_ = enabled; } inline long SSLClient::get_openssl_verify_result() const { - return verify_result_; + return verify_result_; } -inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } +inline SSL_CTX *SSLClient::ssl_context() const { + return ctx_; +} inline bool SSLClient::create_and_connect_socket(Socket &socket) { - return is_valid() && ClientImpl::create_and_connect_socket(socket); + return is_valid() && ClientImpl::create_and_connect_socket(socket); } inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success) { - success = true; - Response res2; - - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { - Request req2; - req2.method = "CONNECT"; - req2.path = host_and_port_; - return process_request(strm, req2, res2, false); - })) { - close_socket(socket, true); - success = false; - return false; - } - - if (res2.status == 407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (detail::parse_www_authenticate(res2, auth, true)) { - Response res3; - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { - Request req3; - req3.method = "CONNECT"; - req3.path = host_and_port_; - req3.headers.insert(detail::make_digest_authentication_header( - req3, auth, 1, detail::random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - return process_request(strm, req3, res3, false); - })) { - close_socket(socket, true); - success = false; - return false; - } - } - } else { - res = res2; - return false; + success = true; + Response res2; + + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false); + })) { + close_socket(socket, true); + success = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false); + })) { + close_socket(socket, true); + success = false; + return false; + } + } + } else { + res = res2; + return false; + } } - } - return true; + return true; } inline bool SSLClient::load_certs() { - bool ret = true; - - std::call_once(initialize_cert_, [&]() { - std::lock_guard guard(ctx_mutex_); - if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { - ret = false; - } - } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { - ret = false; - } - } else if (ca_cert_store_ != nullptr) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { - SSL_CTX_set_cert_store(ctx_, ca_cert_store_); - } - } else { + bool ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + ret = false; + } + } else if (ca_cert_store_ != nullptr) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { + SSL_CTX_set_cert_store(ctx_, ca_cert_store_); + } + } else { #ifdef _WIN32 - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); #else SSL_CTX_set_default_verify_paths(ctx_); #endif - } - }); + } + }); - return ret; + return ret; } inline bool SSLClient::initialize_ssl(Socket &socket) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl) { - if (server_certificate_verification_) { - if (!load_certs()) { - error_ = Error::SSLLoadingCerts; - return false; - } - SSL_set_verify(ssl, SSL_VERIFY_NONE, nullptr); - } + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (server_certificate_verification_) { + if (!load_certs()) { + error_ = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl, SSL_VERIFY_NONE, nullptr); + } - if (SSL_connect(ssl) != 1) { - error_ = Error::SSLConnection; - return false; - } + if (SSL_connect(ssl) != 1) { + error_ = Error::SSLConnection; + return false; + } - if (server_certificate_verification_) { - verify_result_ = SSL_get_verify_result(ssl); + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); - if (verify_result_ != X509_V_OK) { - error_ = Error::SSLServerVerification; - return false; - } + if (verify_result_ != X509_V_OK) { + error_ = Error::SSLServerVerification; + return false; + } - auto server_cert = SSL_get_peer_certificate(ssl); + auto server_cert = SSL_get_peer_certificate(ssl); - if (server_cert == nullptr) { - error_ = Error::SSLServerVerification; - return false; - } + if (server_cert == nullptr) { + error_ = Error::SSLServerVerification; + return false; + } - if (!verify_host(server_cert)) { - X509_free(server_cert); - error_ = Error::SSLServerVerification; - return false; - } - X509_free(server_cert); - } + if (!verify_host(server_cert)) { + X509_free(server_cert); + error_ = Error::SSLServerVerification; + return false; + } + X509_free(server_cert); + } - return true; - }, - [&](SSL *ssl) { - SSL_set_tlsext_host_name(ssl, host_.c_str()); - return true; - }); + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }); - if (ssl) { - socket.ssl = ssl; - return true; - } + if (ssl) { + socket.ssl = ssl; + return true; + } - close_socket(socket, false); - return false; + close_socket(socket, false); + return false; } inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); - socket_.ssl = nullptr; - } + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); + socket_.ssl = nullptr; + } } inline bool SSLClient::process_socket(Socket &socket, std::function callback) { - assert(socket.ssl); - return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, callback); + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, callback); } -inline bool SSLClient::is_ssl() const { return true; } +inline bool SSLClient::is_ssl() const { + return true; +} inline bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" + /* Quote from RFC2818 section 3.1 "Server Identity" - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); } inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; + auto ret = false; - auto type = GEN_DNS; + auto type = GEN_DNS; - struct in6_addr addr6; - struct in_addr addr; - size_t addr_len = 0; + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; #ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); - } + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } #endif - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_mached = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (val->type == type) { - auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); - auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); - - if (strlen(name) == name_len) { - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || - !memcmp(&addr, name, addr_len)) { - ip_mached = true; + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } } - break; - } } - } - } - if (dsn_matched || ip_mached) { ret = true; } - } + if (dsn_matched || ip_mached) { + ret = true; + } + } - GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); - return ret; + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + return ret; } inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); + const auto subject_name = X509_get_subject_name(server_cert); - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); - if (name_len != -1) { - return check_host_name(name, static_cast(name_len)); + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } } - } - return false; + return false; } inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const { - if (host_.size() == pattern_len && host_ == pattern) { return true; } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(std::string(b, e)); - }); + if (host_.size() == pattern_len && host_ == pattern) { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); - if (host_components_.size() != pattern_components.size()) { return false; } + if (host_components_.size() != pattern_components.size()) { + return false; + } - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (p != h && p != "*") { - auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && - !p.compare(0, p.size() - 1, h)); - if (!partial_match) { return false; } + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { + return false; + } + } + ++itr; } - ++itr; - } - return true; + return true; } #endif // Universal client implementation inline Client::Client(const char *scheme_host_port) - : Client(scheme_host_port, std::string(), std::string()) {} + : Client(scheme_host_port, std::string(), std::string()) { +} inline Client::Client(const char *scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { - const static std::regex re(R"(^(?:([a-z]+)://)?([^:/?#]+)(?::(\d+))?)"); + const static std::regex re(R"(^(?:([a-z]+)://)?([^:/?#]+)(?::(\d+))?)"); - std::cmatch m; - if (std::regex_match(scheme_host_port, m, re)) { - auto scheme = m[1].str(); + std::cmatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (!scheme.empty() && (scheme != "http" && scheme != "https")) { + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { #else - if (!scheme.empty() && scheme != "http") { + if (!scheme.empty() && scheme != "http") { #endif - return; - } + return; + } - auto is_ssl = scheme == "https"; + auto is_ssl = scheme == "https"; - auto host = m[2].str(); + auto host = m[2].str(); - auto port_str = m[3].str(); - auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + auto port_str = m[3].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); - if (is_ssl) { + if (is_ssl) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - cli_ = std::make_shared(host.c_str(), port, client_cert_path, - client_key_path); - is_ssl_ = is_ssl; + cli_ = std::make_shared(host.c_str(), port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; #endif + } else { + cli_ = std::make_shared(host.c_str(), port, client_cert_path, + client_key_path); + } } else { - cli_ = std::make_shared(host.c_str(), port, client_cert_path, - client_key_path); + cli_ = std::make_shared(scheme_host_port, 80, client_cert_path, + client_key_path); } - } else { - cli_ = std::make_shared(scheme_host_port, 80, client_cert_path, - client_key_path); - } } inline Client::Client(const std::string &host, int port) - : cli_(std::make_shared(host, port)) {} + : cli_(std::make_shared(host, port)) { +} inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : cli_(std::make_shared(host, port, client_cert_path, - client_key_path)) {} + client_key_path)) { +} -inline Client::~Client() {} +inline Client::~Client() { +} inline bool Client::is_valid() const { - return cli_ != nullptr && cli_->is_valid(); + return cli_ != nullptr && cli_->is_valid(); } -inline Result Client::Get(const char *path) { return cli_->Get(path); } +inline Result Client::Get(const char *path) { + return cli_->Get(path); +} inline Result Client::Get(const char *path, const Headers &headers) { - return cli_->Get(path, headers); + return cli_->Get(path, headers); } inline Result Client::Get(const char *path, Progress progress) { - return cli_->Get(path, progress); + return cli_->Get(path, progress); } inline Result Client::Get(const char *path, const Headers &headers, Progress progress) { - return cli_->Get(path, headers, progress); + return cli_->Get(path, headers, progress); } inline Result Client::Get(const char *path, ContentReceiver content_receiver) { - return cli_->Get(path, std::move(content_receiver)); + return cli_->Get(path, std::move(content_receiver)); } inline Result Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) { - return cli_->Get(path, headers, std::move(content_receiver)); + return cli_->Get(path, headers, std::move(content_receiver)); } inline Result Client::Get(const char *path, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, std::move(content_receiver), std::move(progress)); + return cli_->Get(path, std::move(content_receiver), std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, std::move(content_receiver), - std::move(progress)); + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); } inline Result Client::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return cli_->Get(path, std::move(response_handler), - std::move(content_receiver)); + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver)); } inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return cli_->Get(path, headers, std::move(response_handler), - std::move(content_receiver)); + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver)); } inline Result Client::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, std::move(response_handler), - std::move(content_receiver), std::move(progress)); + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, response_handler, content_receiver, progress); + return cli_->Get(path, headers, response_handler, content_receiver, progress); } -inline Result Client::Head(const char *path) { return cli_->Head(path); } +inline Result Client::Head(const char *path) { + return cli_->Head(path); +} inline Result Client::Head(const char *path, const Headers &headers) { - return cli_->Head(path, headers); + return cli_->Head(path, headers); } -inline Result Client::Post(const char *path) { return cli_->Post(path); } +inline Result Client::Post(const char *path) { + return cli_->Post(path); +} inline Result Client::Post(const char *path, const std::string &body, const char *content_type) { - return cli_->Post(path, body, content_type); + return cli_->Post(path, body, content_type); } inline Result Client::Post(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Post(path, headers, body, content_type); + return cli_->Post(path, headers, body, content_type); } inline Result Client::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, content_length, content_provider, content_type); + return cli_->Post(path, content_length, content_provider, content_type); } inline Result Client::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, headers, content_length, content_provider, - content_type); + return cli_->Post(path, headers, content_length, content_provider, + content_type); } inline Result Client::Post(const char *path, const Params ¶ms) { - return cli_->Post(path, params); + return cli_->Post(path, params); } inline Result Client::Post(const char *path, const Headers &headers, const Params ¶ms) { - return cli_->Post(path, headers, params); + return cli_->Post(path, headers, params); } inline Result Client::Post(const char *path, const MultipartFormDataItems &items) { - return cli_->Post(path, items); + return cli_->Post(path, items); } inline Result Client::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - return cli_->Post(path, headers, items); + return cli_->Post(path, headers, items); +} +inline Result Client::Put(const char *path) { + return cli_->Put(path); } -inline Result Client::Put(const char *path) { return cli_->Put(path); } inline Result Client::Put(const char *path, const std::string &body, const char *content_type) { - return cli_->Put(path, body, content_type); + return cli_->Put(path, body, content_type); } inline Result Client::Put(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Put(path, headers, body, content_type); + return cli_->Put(path, headers, body, content_type); } inline Result Client::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, content_length, content_provider, content_type); + return cli_->Put(path, content_length, content_provider, content_type); } inline Result Client::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, headers, content_length, content_provider, - content_type); + return cli_->Put(path, headers, content_length, content_provider, + content_type); } inline Result Client::Put(const char *path, const Params ¶ms) { - return cli_->Put(path, params); + return cli_->Put(path, params); } inline Result Client::Put(const char *path, const Headers &headers, const Params ¶ms) { - return cli_->Put(path, headers, params); + return cli_->Put(path, headers, params); } inline Result Client::Patch(const char *path, const std::string &body, const char *content_type) { - return cli_->Patch(path, body, content_type); + return cli_->Patch(path, body, content_type); } inline Result Client::Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Patch(path, headers, body, content_type); + return cli_->Patch(path, headers, body, content_type); } inline Result Client::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, content_length, content_provider, content_type); + return cli_->Patch(path, content_length, content_provider, content_type); } inline Result Client::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, headers, content_length, content_provider, - content_type); + return cli_->Patch(path, headers, content_length, content_provider, + content_type); +} +inline Result Client::Delete(const char *path) { + return cli_->Delete(path); } -inline Result Client::Delete(const char *path) { return cli_->Delete(path); } inline Result Client::Delete(const char *path, const std::string &body, const char *content_type) { - return cli_->Delete(path, body, content_type); + return cli_->Delete(path, body, content_type); } inline Result Client::Delete(const char *path, const Headers &headers) { - return cli_->Delete(path, headers); + return cli_->Delete(path, headers); } inline Result Client::Delete(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Delete(path, headers, body, content_type); + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Options(const char *path) { + return cli_->Options(path); } -inline Result Client::Options(const char *path) { return cli_->Options(path); } inline Result Client::Options(const char *path, const Headers &headers) { - return cli_->Options(path, headers); + return cli_->Options(path, headers); } inline bool Client::send(const Request &req, Response &res) { - return cli_->send(req, res); + return cli_->send(req, res); } -inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } +inline size_t Client::is_socket_open() const { + return cli_->is_socket_open(); +} -inline void Client::stop() { cli_->stop(); } +inline void Client::stop() { + cli_->stop(); +} inline void Client::set_default_headers(Headers headers) { - cli_->set_default_headers(std::move(headers)); + cli_->set_default_headers(std::move(headers)); } -inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } +inline void Client::set_tcp_nodelay(bool on) { + cli_->set_tcp_nodelay(on); +} inline void Client::set_socket_options(SocketOptions socket_options) { - cli_->set_socket_options(socket_options); + cli_->set_socket_options(socket_options); } inline void Client::set_connection_timeout(time_t sec, time_t usec) { - cli_->set_connection_timeout(sec, usec); + cli_->set_connection_timeout(sec, usec); } inline void Client::set_read_timeout(time_t sec, time_t usec) { - cli_->set_read_timeout(sec, usec); + cli_->set_read_timeout(sec, usec); } inline void Client::set_write_timeout(time_t sec, time_t usec) { - cli_->set_write_timeout(sec, usec); + cli_->set_write_timeout(sec, usec); } inline void Client::set_basic_auth(const char *username, const char *password) { - cli_->set_basic_auth(username, password); + cli_->set_basic_auth(username, password); } inline void Client::set_bearer_token_auth(const char *token) { - cli_->set_bearer_token_auth(token); + cli_->set_bearer_token_auth(token); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void Client::set_digest_auth(const char *username, const char *password) { - cli_->set_digest_auth(username, password); + cli_->set_digest_auth(username, password); } #endif -inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_keep_alive(bool on) { + cli_->set_keep_alive(on); +} inline void Client::set_follow_location(bool on) { - cli_->set_follow_location(on); + cli_->set_follow_location(on); } -inline void Client::set_compress(bool on) { cli_->set_compress(on); } +inline void Client::set_compress(bool on) { + cli_->set_compress(on); +} -inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } +inline void Client::set_decompress(bool on) { + cli_->set_decompress(on); +} inline void Client::set_interface(const char *intf) { - cli_->set_interface(intf); + cli_->set_interface(intf); } inline void Client::set_proxy(const char *host, int port) { - cli_->set_proxy(host, port); + cli_->set_proxy(host, port); } inline void Client::set_proxy_basic_auth(const char *username, const char *password) { - cli_->set_proxy_basic_auth(username, password); + cli_->set_proxy_basic_auth(username, password); } inline void Client::set_proxy_bearer_token_auth(const char *token) { - cli_->set_proxy_bearer_token_auth(token); + cli_->set_proxy_bearer_token_auth(token); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void Client::set_proxy_digest_auth(const char *username, const char *password) { - cli_->set_proxy_digest_auth(username, password); + cli_->set_proxy_digest_auth(username, password); } #endif -inline void Client::set_logger(Logger logger) { cli_->set_logger(logger); } +inline void Client::set_logger(Logger logger) { + cli_->set_logger(logger); +} #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline Client &Client::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - if (is_ssl_) { - static_cast(*cli_).set_ca_cert_path(ca_cert_file_path, - ca_cert_dir_path); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_path(ca_cert_file_path, + ca_cert_dir_path); + } + return *this; } inline Client &Client::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (is_ssl_) { - static_cast(*cli_).set_ca_cert_store(ca_cert_store); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } + return *this; } inline Client &Client::enable_server_certificate_verification(bool enabled) { - if (is_ssl_) { - static_cast(*cli_).enable_server_certificate_verification( - enabled); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).enable_server_certificate_verification( + enabled); + } + return *this; } inline long Client::get_openssl_verify_result() const { - if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); - } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } inline SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } - return nullptr; + if (is_ssl_) { + return static_cast(*cli_).ssl_context(); + } + return nullptr; } #endif // ---------------------------------------------------------------------------- -} // namespace httplib +} // namespace httplib -#endif // CPPHTTPLIB_HTTPLIB_H +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/src/bb/dnn/onnxruntime_c.h b/src/bb/dnn/onnxruntime_c.h index d8c38448..7456066f 100644 --- a/src/bb/dnn/onnxruntime_c.h +++ b/src/bb/dnn/onnxruntime_c.h @@ -171,7 +171,7 @@ ORT_RUNTIME_CLASS(Env); ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success ORT_RUNTIME_CLASS(MemoryInfo); ORT_RUNTIME_CLASS(IoBinding); -ORT_RUNTIME_CLASS(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) +ORT_RUNTIME_CLASS(Session); // Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) ORT_RUNTIME_CLASS(Value); ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); @@ -193,8 +193,9 @@ typedef OrtStatus *OrtStatusPtr; // __VA_ARGS__ on Windows and Linux are different #define ORT_API(RETURN_TYPE, NAME, ...) ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION -#define ORT_API_STATUS(NAME, ...) \ - ORT_EXPORT _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT +#define ORT_API_STATUS(NAME, ...) \ + ORT_EXPORT _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ + NO_EXCEPTION ORT_MUST_USE_RESULT // XXX: Unfortunately, SAL annotations are known to not work with function pointers #define ORT_API2_STATUS(NAME, ...) \ @@ -260,7 +261,7 @@ typedef enum OrtAllocatorType { /** * memory types for allocator, exec provider specific types should be extended in each provider * Whenever this struct is updated, please also update the MakeKey function in onnxruntime/core/framework/execution_provider.cc -*/ + */ typedef enum OrtMemType { OrtMemTypeCPUInput = -2, // Any CPU memory used by non-CPU execution provider OrtMemTypeCPUOutput = -1, // CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED @@ -281,16 +282,16 @@ typedef struct OrtApiBase OrtApiBase; struct OrtApi { /** -* \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus -*/ + * \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus + */ OrtStatus *(ORT_API_CALL *CreateStatus)(OrtErrorCode code, _In_ const char *msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; OrtErrorCode(ORT_API_CALL *GetErrorCode)(_In_ const OrtStatus *status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** - * \param status must not be NULL - * \return The error message inside the `status`. Do not free the returned value. - */ + * \param status must not be NULL + * \return The error message inside the `status`. Do not free the returned value. + */ const char *(ORT_API_CALL *GetErrorMessage)(_In_ const OrtStatus *status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** @@ -299,8 +300,8 @@ struct OrtApi { ORT_API2_STATUS(CreateEnv, OrtLoggingLevel default_logging_level, _In_ const char *logid, _Outptr_ OrtEnv **out); /** - * \param out Should be freed by `OrtReleaseEnv` after use - */ + * \param out Should be freed by `OrtReleaseEnv` after use + */ ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void *logger_param, OrtLoggingLevel default_warning_level, _In_ const char *logid, _Outptr_ OrtEnv **out); @@ -326,8 +327,8 @@ struct OrtApi { _Inout_updates_all_(output_names_len) OrtValue **output); /** - * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use - */ + * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use + */ ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions **options); // Set filepath to save optimized model after graph level transformations. @@ -390,13 +391,13 @@ struct OrtApi { /* * Add custom ops to the OrtCustomOpDomain * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released - */ + */ ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain *custom_op_domain, _In_ OrtCustomOp *op); /* * Add a custom op domain to the OrtSessionOptions * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released - */ + */ ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions *options, _In_ OrtCustomOpDomain *custom_op_domain); /* @@ -405,45 +406,45 @@ struct OrtApi { * It then passes in the provided session options to this function along with the api base. * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in * session options are destroyed, or if an error occurs and it is non null. - */ + */ ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions *options, _In_ const char *library_path, void **library_handle); /** - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session: - * OrtSessionOptionsAppendExecutionProvider_CPU - * OrtSessionOptionsAppendExecutionProvider_CUDA - * OrtSessionOptionsAppendExecutionProvider_ - * The order they are called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. - */ + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session: + * OrtSessionOptionsAppendExecutionProvider_CPU + * OrtSessionOptionsAppendExecutionProvider_CUDA + * OrtSessionOptionsAppendExecutionProvider_ + * The order they are called indicates the preference order as well. In other words call this method + * on your most preferred execution provider first followed by the less preferred ones. + * If none are called Ort will use its internal CPU execution provider. + */ ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession *sess, _Out_ size_t *out); ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession *sess, _Out_ size_t *out); ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession *sess, _Out_ size_t *out); /** - * \param out should be freed by OrtReleaseTypeInfo after use - */ + * \param out should be freed by OrtReleaseTypeInfo after use + */ ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession *sess, size_t index, _Outptr_ OrtTypeInfo **type_info); /** - * \param out should be freed by OrtReleaseTypeInfo after use - */ + * \param out should be freed by OrtReleaseTypeInfo after use + */ ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession *sess, size_t index, _Outptr_ OrtTypeInfo **type_info); /** - * \param out should be freed by OrtReleaseTypeInfo after use - */ + * \param out should be freed by OrtReleaseTypeInfo after use + */ ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession *sess, size_t index, _Outptr_ OrtTypeInfo **type_info); /** - * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. - */ + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + */ ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession *sess, size_t index, _Inout_ OrtAllocator *allocator, _Outptr_ char **value); ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession *sess, size_t index, _Inout_ OrtAllocator *allocator, @@ -452,8 +453,8 @@ struct OrtApi { _Inout_ OrtAllocator *allocator, _Outptr_ char **value); /** - * \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use - */ + * \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use + */ ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions **out); ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions *options, int value); @@ -471,25 +472,25 @@ struct OrtApi { ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions *options); /** - * Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value - * \param out Should be freed by calling OrtReleaseValue - * \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - */ + * Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value + * \param out Should be freed by calling OrtReleaseValue + * \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + */ ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator *allocator, _In_ const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue **out); /** - * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. - * p_data is owned by caller. OrtReleaseValue won't release p_data. - * \param out Should be freed by calling OrtReleaseValue - */ + * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. + * p_data is owned by caller. OrtReleaseValue won't release p_data. + * \param out Should be freed by calling OrtReleaseValue + */ ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo *info, _Inout_ void *p_data, size_t p_data_len, _In_ const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue **out); /** - * \Sets *out to 1 iff an OrtValue is a tensor, 0 otherwise - */ + * \Sets *out to 1 iff an OrtValue is a tensor, 0 otherwise + */ ORT_API2_STATUS(IsTensor, _In_ const OrtValue *value, _Out_ int *out); // This function doesn't work with string tensor @@ -536,10 +537,10 @@ struct OrtApi { ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo *, enum ONNXTensorElementDataType type); /** - * \param info Created from CreateTensorTypeAndShapeInfo() function - * \param dim_values An array with length of `dim_count`. Its elements can contain negative values. - * \param dim_count length of dim_values - */ + * \param info Created from CreateTensorTypeAndShapeInfo() function + * \param dim_values An array with length of `dim_count`. Its elements can contain negative values. + * \param dim_count length of dim_values + */ ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo *info, _In_ const int64_t *dim_values, size_t dim_count); ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo *, @@ -551,26 +552,26 @@ struct OrtApi { _Out_writes_all_(dim_params_length) const char *dim_params[], size_t dim_params_length); /** - * Return the number of elements specified by the tensor shape. - * Return a negative value if unknown (i.e., any dimension is negative.) - * e.g. - * [] -> 1 - * [1,3,4] -> 12 - * [2,0,4] -> 0 - * [-1,3,4] -> -1 - */ + * Return the number of elements specified by the tensor shape. + * Return a negative value if unknown (i.e., any dimension is negative.) + * e.g. + * [] -> 1 + * [1,3,4] -> 12 + * [2,0,4] -> 0 + * [-1,3,4] -> -1 + */ ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo *info, _Out_ size_t *out); /** - * \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use - */ + * \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use + */ ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue *value, _Outptr_ OrtTensorTypeAndShapeInfo **out); /** - * Get the type information of an OrtValue - * \param value - * \param out The returned value should be freed by OrtReleaseTypeInfo after use - */ + * Get the type information of an OrtValue + * \param value + * \param out The returned value should be freed by OrtReleaseTypeInfo after use + */ ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue *value, _Outptr_result_maybenull_ OrtTypeInfo **out); ORT_API2_STATUS(GetValueType, _In_ const OrtValue *value, _Out_ enum ONNXType *out); @@ -579,20 +580,20 @@ struct OrtApi { enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo **out); /** - * Convenience function for special case of CreateMemoryInfo, for the CPU allocator. Uses name = "Cpu" and id = 0. - */ + * Convenience function for special case of CreateMemoryInfo, for the CPU allocator. Uses name = "Cpu" and id = 0. + */ ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo **out); /** - * Test if two memory info are equal - * \Sets 'out' to 0 if equal, -1 if not equal - */ + * Test if two memory info are equal + * \Sets 'out' to 0 if equal, -1 if not equal + */ ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo *info1, _In_ const OrtMemoryInfo *info2, _Out_ int *out); /** - * Do not free the returned value - */ + * Do not free the returned value + */ ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo *ptr, _Out_ const char **out); ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo *ptr, _Out_ int *out); ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo *ptr, _Out_ OrtMemType *out); @@ -613,52 +614,52 @@ struct OrtApi { _In_ int64_t dim_value); /** - * APIs to support non-tensor types - map and sequence. - * Currently only the following types are supported - * Note: the following types should be kept in sync with data_types.h - * Map types - * ========= - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * - * Sequence types - * ============== - * std::vector - * std::vector - * std::vector - * std::vector - * std::vector> - * std::vector - */ + * APIs to support non-tensor types - map and sequence. + * Currently only the following types are supported + * Note: the following types should be kept in sync with data_types.h + * Map types + * ========= + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * + * Sequence types + * ============== + * std::vector + * std::vector + * std::vector + * std::vector + * std::vector> + * std::vector + */ /** - * If input OrtValue represents a map, you need to retrieve the keys and values - * separately. Use index=0 to retrieve keys and index=1 to retrieve values. - * If input OrtValue represents a sequence, use index to retrieve the index'th element - * of the sequence. - */ + * If input OrtValue represents a map, you need to retrieve the keys and values + * separately. Use index=0 to retrieve keys and index=1 to retrieve values. + * If input OrtValue represents a sequence, use index to retrieve the index'th element + * of the sequence. + */ ORT_API2_STATUS(GetValue, _In_ const OrtValue *value, int index, _Inout_ OrtAllocator *allocator, _Outptr_ OrtValue **out); /** - * Returns 2 for type map and N for sequence where N is the number of elements - * in the sequence. - */ + * Returns 2 for type map and N for sequence where N is the number of elements + * in the sequence. + */ ORT_API2_STATUS(GetValueCount, _In_ const OrtValue *value, _Out_ size_t *out); /** - * To construct a map, use num_values = 2 and 'in' should be an arrary of 2 OrtValues - * representing keys and values. - * To construct a sequence, use num_values = N where N is the number of the elements in the - * sequence. 'in' should be an arrary of N OrtValues. - * \value_type should be either map or sequence. - */ + * To construct a map, use num_values = 2 and 'in' should be an arrary of 2 OrtValues + * representing keys and values. + * To construct a sequence, use num_values = N where N is the number of the elements in the + * sequence. 'in' should be an arrary of N OrtValues. + * \value_type should be either map or sequence. + */ ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue *const *in, size_t num_values, enum ONNXType value_type, _Outptr_ OrtValue **out); @@ -714,7 +715,7 @@ struct OrtApi { ORT_CLASS_RELEASE(Env); ORT_CLASS_RELEASE(Status); // nullptr for Status* indicates success ORT_CLASS_RELEASE(MemoryInfo); - ORT_CLASS_RELEASE(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) + ORT_CLASS_RELEASE(Session); // Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) ORT_CLASS_RELEASE(Value); ORT_CLASS_RELEASE(RunOptions); ORT_CLASS_RELEASE(TypeInfo); @@ -727,60 +728,60 @@ struct OrtApi { // Version 2 - In development, feel free to add/remove/rearrange here /** - * GetDenotationFromTypeInfo - * This api augments OrtTypeInfo to return denotations on the type. - * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. - */ + * GetDenotationFromTypeInfo + * This api augments OrtTypeInfo to return denotations on the type. + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + */ ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo *, _Out_ const char **const denotation, _Out_ size_t *len); // OrtTypeInfo Casting methods /** - * CastTypeInfoToMapTypeInfo - * This api augments OrtTypeInfo to return an OrtMapTypeInfo when the type is a map. - * The OrtMapTypeInfo has additional information about the map's key type and value type. - * This is used by WinML to support model reflection APIs. - * This is used by WinML to support model reflection APIs. - * - * Don't free the 'out' value - */ + * CastTypeInfoToMapTypeInfo + * This api augments OrtTypeInfo to return an OrtMapTypeInfo when the type is a map. + * The OrtMapTypeInfo has additional information about the map's key type and value type. + * This is used by WinML to support model reflection APIs. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo *type_info, _Outptr_result_maybenull_ const OrtMapTypeInfo **out); /** - * CastTypeInfoToSequenceTypeInfo - * This api augments OrtTypeInfo to return an OrtSequenceTypeInfo when the type is a sequence. - * The OrtSequenceTypeInfo has additional information about the sequence's element type. - * This is used by WinML to support model reflection APIs. - * - * Don't free the 'out' value - */ + * CastTypeInfoToSequenceTypeInfo + * This api augments OrtTypeInfo to return an OrtSequenceTypeInfo when the type is a sequence. + * The OrtSequenceTypeInfo has additional information about the sequence's element type. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo *type_info, _Outptr_result_maybenull_ const OrtSequenceTypeInfo **out); // OrtMapTypeInfo Accessors /** - * GetMapKeyType - * This api augments get the key type of a map. Key types are restricted to being scalar types and use ONNXTensorElementDataType. - * This is used by WinML to support model reflection APIs. - */ + * GetMapKeyType + * This api augments get the key type of a map. Key types are restricted to being scalar types and use ONNXTensorElementDataType. + * This is used by WinML to support model reflection APIs. + */ ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo *map_type_info, _Out_ enum ONNXTensorElementDataType *out); /** - * GetMapValueType - * This api augments get the value type of a map. - */ + * GetMapValueType + * This api augments get the value type of a map. + */ ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo *map_type_info, _Outptr_ OrtTypeInfo **type_info); // OrtSequenceTypeInfo Accessors /** - * GetSequenceElementType - * This api augments get the element type of a sequence. - * This is used by WinML to support model reflection APIs. - */ + * GetSequenceElementType + * This api augments get the element type of a sequence. + * This is used by WinML to support model reflection APIs. + */ ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo *sequence_type_info, _Outptr_ OrtTypeInfo **type_info); @@ -788,20 +789,20 @@ struct OrtApi { ORT_CLASS_RELEASE(SequenceTypeInfo); /** - * \param out is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. - * Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling() - * on the SessionOptions instance used to create the session. - */ + * \param out is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + * Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling() + * on the SessionOptions instance used to create the session. + */ ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession *sess, _Inout_ OrtAllocator *allocator, _Outptr_ char **out); /** - * \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use. - */ + * \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use. + */ ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession *sess, _Outptr_ OrtModelMetadata **out); /** - * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. - */ + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + */ ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata *model_metadata, _Inout_ OrtAllocator *allocator, _Outptr_ char **value); ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata *model_metadata, @@ -811,9 +812,9 @@ struct OrtApi { ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata *model_metadata, _Inout_ OrtAllocator *allocator, _Outptr_ char **value); /** - * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. - * 'value' will be a nullptr if the given key is not found in the custom metadata map. - */ + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + * 'value' will be a nullptr if the given key is not found in the custom metadata map. + */ ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata *model_metadata, _Inout_ OrtAllocator *allocator, _In_ const char *key, _Outptr_result_maybenull_ char **value); @@ -822,19 +823,19 @@ struct OrtApi { ORT_CLASS_RELEASE(ModelMetadata); /* - * Creates an environment with global threadpools that will be shared across sessions. - * Use this in conjunction with DisablePerSessionThreads API or else the session will use - * its own thread pools. - */ + * Creates an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with DisablePerSessionThreads API or else the session will use + * its own thread pools. + */ ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_logging_level, _In_ const char *logid, _In_ const OrtThreadingOptions *t_options, _Outptr_ OrtEnv **out); /* TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? */ /* - * Calling this API will make the session use the global threadpools shared across sessions. - * This API should be used in conjunction with CreateEnvWithGlobalThreadPools API. - */ + * Calling this API will make the session use the global threadpools shared across sessions. + * This API should be used in conjunction with CreateEnvWithGlobalThreadPools API. + */ ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions *options); ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions **out); @@ -842,11 +843,11 @@ struct OrtApi { ORT_CLASS_RELEASE(ThreadingOptions); /** - * \param num_keys contains the number of keys in the custom metadata map - * \param keys is an array of null terminated strings (array count = num_keys) allocated using 'allocator'. - * The caller is responsible for freeing each string and the pointer array. - * 'keys' will be a nullptr if custom metadata map is empty. - */ + * \param num_keys contains the number of keys in the custom metadata map + * \param keys is an array of null terminated strings (array count = num_keys) allocated using 'allocator'. + * The caller is responsible for freeing each string and the pointer array. + * 'keys' will be a nullptr if custom metadata map is empty. + */ ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata *model_metadata, _Inout_ OrtAllocator *allocator, _Outptr_result_buffer_maybenull_(*num_keys) char ***keys, _Out_ int64_t *num_keys); @@ -858,21 +859,21 @@ struct OrtApi { _In_ int64_t dim_value); /** - * \param out_ptr will hold a pointer to the array of char * - * representing available providers. - * \param provider_length is a pointer to an int variable where - * the number of available providers will be added. - * The caller is responsible for freeing each char * and the pointer - * array by calling ReleaseAvailableProviders(). - */ + * \param out_ptr will hold a pointer to the array of char * + * representing available providers. + * \param provider_length is a pointer to an int variable where + * the number of available providers will be added. + * The caller is responsible for freeing each char * and the pointer + * array by calling ReleaseAvailableProviders(). + */ ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char ***out_ptr, _In_ int *provider_length); /** - * \param ptr is the pointer to an array of available providers you - * get after calling GetAvailableProviders(). - * \param providers_length is the number of available providers. - */ + * \param ptr is the pointer to an array of available providers you + * get after calling GetAvailableProviders(). + * \param providers_length is the number of available providers. + */ ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char **ptr, _In_ int providers_length); @@ -909,12 +910,12 @@ struct OrtApi { _In_z_ const char *config_key, _In_z_ const char *config_value); /** - * \param sess valid OrtSession instance - * \param mem_info - valid OrtMemoryInfo instance - * \param - out a ptr to a new instance of OrtAllocator according to the spec within mem_info - * if successful - * \return OrtStatus or nullptr if successful - */ + * \param sess valid OrtSession instance + * \param mem_info - valid OrtMemoryInfo instance + * \param - out a ptr to a new instance of OrtAllocator according to the spec within mem_info + * if successful + * \return OrtStatus or nullptr if successful + */ ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession *sess, _In_ const OrtMemoryInfo *mem_info, _Outptr_ OrtAllocator **out); @@ -932,131 +933,131 @@ struct OrtApi { ORT_CLASS_RELEASE(IoBinding); /** - * The function will bind the OrtValue to a specified input name. - * The OrtValue must be a Tensor. ORT would use that value in place of input for the specified name. - * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() - * \param name - name for the model input - * \param val_ptr - OrtValue of Tensor type. - * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success - */ + * The function will bind the OrtValue to a specified input name. + * The OrtValue must be a Tensor. ORT would use that value in place of input for the specified name. + * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() + * \param name - name for the model input + * \param val_ptr - OrtValue of Tensor type. + * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success + */ ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding *binding_ptr, _In_ const char *name, _In_ const OrtValue *val_ptr); /** - * The function will bind the OrtValue to the specified output name. - * The OrtValue must be a Tensor. ORT would use that value in place of output for the specified name. - * - * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() - * \param name - name for the model output - * \param val_ptr - OrtValue of Tensor type. - * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success - */ + * The function will bind the OrtValue to the specified output name. + * The OrtValue must be a Tensor. ORT would use that value in place of output for the specified name. + * + * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() + * \param name - name for the model output + * \param val_ptr - OrtValue of Tensor type. + * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success + */ ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding *binding_ptr, _In_ const char *name, _In_ const OrtValue *val_ptr); /** - * The function will bind the OrtValue to a device which specification is contained within OrtMemoryInfo - * You can either create an instance of OrtMemoryInfo with a device id or obtain one from the allocator that you are created/using - * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocated and bind a chunk of - * memory within OrtValue ahead of time. - * - * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() - * \param name - name for the model output - * \param mem_info_ptr - OrtMemoryInfo - * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success - */ + * The function will bind the OrtValue to a device which specification is contained within OrtMemoryInfo + * You can either create an instance of OrtMemoryInfo with a device id or obtain one from the allocator that you are created/using + * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocated and bind a chunk of + * memory within OrtValue ahead of time. + * + * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() + * \param name - name for the model output + * \param mem_info_ptr - OrtMemoryInfo + * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success + */ ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding *binding_ptr, _In_ const char *name, _In_ const OrtMemoryInfo *val_ptr); /** - * The function returns the names of the outputs in the order they were bound. This is useful after running the model - * with bound outputs because the returned names are in order in which output OrtValues are returned. This API is optional - * to use. If you knew the order of outputs and its names you used for binding you would not need to use this API. - * - * \param binding_ptr - a ptr to an instance of OrtIoBinding created obtained from CreateIoBinding() - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param buffer - pointer to a continuous buffer of non-zero terminated UTF-8 encoded strings. The number of strings stored is returned count parameter. - * this buffer will be allocated with the specified allocator and must be freed after it is no longer needed. - * \param lengths - a pointer to a continuous buffer of size_t lengths of strings returned in the buffer. The number of items is returned - * in the count. This buffer is allocated with the specified allocator and must be freed after it is no longer needed. - * \para count - is the number of strings returned. If the instance of OrtIoBiding has no bound outputs, zero is returned, - * no memory allocation is performed and buffer and lengths are nullptr on return. - */ + * The function returns the names of the outputs in the order they were bound. This is useful after running the model + * with bound outputs because the returned names are in order in which output OrtValues are returned. This API is optional + * to use. If you knew the order of outputs and its names you used for binding you would not need to use this API. + * + * \param binding_ptr - a ptr to an instance of OrtIoBinding created obtained from CreateIoBinding() + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param buffer - pointer to a continuous buffer of non-zero terminated UTF-8 encoded strings. The number of strings stored is returned count parameter. + * this buffer will be allocated with the specified allocator and must be freed after it is no longer needed. + * \param lengths - a pointer to a continuous buffer of size_t lengths of strings returned in the buffer. The number of items is returned + * in the count. This buffer is allocated with the specified allocator and must be freed after it is no longer needed. + * \para count - is the number of strings returned. If the instance of OrtIoBiding has no bound outputs, zero is returned, + * no memory allocation is performed and buffer and lengths are nullptr on return. + */ ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding *binding_ptr, _In_ OrtAllocator *allocator, _Out_ char **buffer, _Out_writes_all_(count) size_t **lengths, _Out_ size_t *count); /** - * The function returns an array of pointers to individually allocated OrtValues that contain results of a model execution with RunWithBinding() - * The array contains the same number of OrtValues and they are in the same order as they were bound with BindOutput() - * or BindOutputToDevice(). - * The returned OrtValues must be individually released after they are no longer needed. - * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after - * all the OrtValues contained therein are individually released. - * - * \param binding_ptr - instance of OrtIoBidning - * \param allocator - instance of allocator to allocate output array - * \param output - pointer to the allocated buffer. Returns nullptr if no outputs. - * \param output_count - pointer to the number of OrtValues returned. Zero if no outputs. - */ + * The function returns an array of pointers to individually allocated OrtValues that contain results of a model execution with RunWithBinding() + * The array contains the same number of OrtValues and they are in the same order as they were bound with BindOutput() + * or BindOutputToDevice(). + * The returned OrtValues must be individually released after they are no longer needed. + * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after + * all the OrtValues contained therein are individually released. + * + * \param binding_ptr - instance of OrtIoBidning + * \param allocator - instance of allocator to allocate output array + * \param output - pointer to the allocated buffer. Returns nullptr if no outputs. + * \param output_count - pointer to the number of OrtValues returned. Zero if no outputs. + */ ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding *binding_ptr, _In_ OrtAllocator *allocator, _Out_writes_all_(output_count) OrtValue ***output, _Out_ size_t *output_count); /** Clears any previously specified bindings for inputs/outputs - */ + */ void(ORT_API_CALL *ClearBoundInputs)(_Inout_ OrtIoBinding *binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; void(ORT_API_CALL *ClearBoundOutputs)(_Inout_ OrtIoBinding *binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** - * Provides element-level access into a tensor. - * \param location_values a pointer to an array of index values that specify an element's location in the tensor data blob - * \param location_values_count length of location_values - * \param out a pointer to the element specified by location_values - * e.g. - * Given a tensor with overall shape [3,224,224], an element at - * location [2,150,128] can be accessed directly. - * - * This function only works for numeric tensors. - * This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd. - */ + * Provides element-level access into a tensor. + * \param location_values a pointer to an array of index values that specify an element's location in the tensor data blob + * \param location_values_count length of location_values + * \param out a pointer to the element specified by location_values + * e.g. + * Given a tensor with overall shape [3,224,224], an element at + * location [2,150,128] can be accessed directly. + * + * This function only works for numeric tensors. + * This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd. + */ ORT_API2_STATUS(TensorAt, _Inout_ OrtValue *value, const int64_t *location_values, size_t location_values_count, _Outptr_ void **out); /** - * Creates an allocator instance and registers it with the env to enable - * sharing between multiple sessions that use the same env instance. - * Lifetime of the created allocator will be valid for the duration of the environment. - * Returns an error if an allocator with the same OrtMemoryInfo is already registered. - * \param mem_info must be non-null. - * \param arena_cfg if nullptr defaults will be used. - * See docs/C_API.md for details. - */ + * Creates an allocator instance and registers it with the env to enable + * sharing between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. + * Returns an error if an allocator with the same OrtMemoryInfo is already registered. + * \param mem_info must be non-null. + * \param arena_cfg if nullptr defaults will be used. + * See docs/C_API.md for details. + */ ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv *env, _In_ const OrtMemoryInfo *mem_info, _In_ const OrtArenaCfg *arena_cfg); /** - * Set the language projection for collecting telemetry data when Env is created - * \param projection the source projected language. - */ + * Set the language projection for collecting telemetry data when Env is created + * \param projection the source projected language. + */ ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv *ort_env, _In_ OrtLanguageProjection projection); /** - * \param out is set to the nanoseconds of profiling's start time - */ + * \param out is set to the nanoseconds of profiling's start time + */ ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession *sess, _Outptr_ uint64_t *out); /** - * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. - * A value of 0 means ORT will pick the default. - * A value of 1 means the invoking thread will be used; no threads will be created in the thread pool. - */ + * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. + * A value of 0 means ORT will pick the default. + * A value of 1 means the invoking thread will be used; no threads will be created in the thread pool. + */ ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions *tp_options, int intra_op_num_threads); ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions *tp_options, int inter_op_num_threads); /** - * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. - * Allow spinning of thread pools when their queues are empty. This API will set the value for both - * inter_op and intra_op threadpools. - * \param allow_spinning valid values are 1 and 0. - * 1: threadpool will spin to wait for queue to become non-empty, 0: it won't spin. - * Prefer a value of 0 if your CPU usage is very high. - */ + * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. + * Allow spinning of thread pools when their queues are empty. This API will set the value for both + * inter_op and intra_op threadpools. + * \param allow_spinning valid values are 1 and 0. + * 1: threadpool will spin to wait for queue to become non-empty, 0: it won't spin. + * Prefer a value of 0 if your CPU usage is very high. + */ ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions *tp_options, int allow_spinning); }; @@ -1075,8 +1076,7 @@ class ONNXRuntime { using OrtSessionOptionsAppendExecutionProvider_CUDA_t = OrtStatus *(*)(OrtSessionOptions *options, int device_id); ONNXRuntime() - : dm_{"onnxruntime"}, api_(nullptr) - { + : dm_{"onnxruntime"}, api_(nullptr) { if (dm_.is_available()) { OrtGetApiBase_t get_ort_api_base = dm_.get_symbol("OrtGetApiBase"); const OrtApi *api = get_ort_api_base()->GetApi(ORT_API_VERSION); diff --git a/src/bb/dnn/picosha2.h b/src/bb/dnn/picosha2.h index bc00c743..10b70a9d 100644 --- a/src/bb/dnn/picosha2.h +++ b/src/bb/dnn/picosha2.h @@ -43,9 +43,13 @@ typedef unsigned char byte_t; static const size_t k_digest_size = 32; namespace detail { -inline byte_t mask_8bit(byte_t x) { return x & 0xff; } +inline byte_t mask_8bit(byte_t x) { + return x & 0xff; +} -inline word_t mask_32bit(word_t x) { return x & 0xffffffff; } +inline word_t mask_32bit(word_t x) { + return x & 0xffffffff; +} const word_t add_constant[64] = { 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, @@ -64,7 +68,9 @@ const word_t initial_message_digest[8] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19}; -inline word_t ch(word_t x, word_t y, word_t z) { return (x & y) ^ ((~x) & z); } +inline word_t ch(word_t x, word_t y, word_t z) { + return (x & y) ^ ((~x) & z); +} inline word_t maj(word_t x, word_t y, word_t z) { return (x & y) ^ (x & z) ^ (y & z); @@ -75,20 +81,28 @@ inline word_t rotr(word_t x, std::size_t n) { return mask_32bit((x >> n) | (x << (32 - n))); } -inline word_t bsig0(word_t x) { return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); } +inline word_t bsig0(word_t x) { + return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); +} -inline word_t bsig1(word_t x) { return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); } +inline word_t bsig1(word_t x) { + return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); +} inline word_t shr(word_t x, std::size_t n) { assert(n < 32); return x >> n; } -inline word_t ssig0(word_t x) { return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3); } +inline word_t ssig0(word_t x) { + return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3); +} -inline word_t ssig1(word_t x) { return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10); } +inline word_t ssig1(word_t x) { + return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10); +} -template +template void hash256_block(RaIter1 message_digest, RaIter2 first, RaIter2 last) { assert(first + 64 == last); static_cast(last); // for avoiding unused-variable warning @@ -141,8 +155,8 @@ void hash256_block(RaIter1 message_digest, RaIter2 first, RaIter2 last) { } // namespace detail -template -void output_hex(InIter first, InIter last, std::ostream& os) { +template +void output_hex(InIter first, InIter last, std::ostream &os) { os.setf(std::ios::hex, std::ios::basefield); while (first != last) { os.width(2); @@ -153,35 +167,37 @@ void output_hex(InIter first, InIter last, std::ostream& os) { os.setf(std::ios::dec, std::ios::basefield); } -template -void bytes_to_hex_string(InIter first, InIter last, std::string& hex_str) { +template +void bytes_to_hex_string(InIter first, InIter last, std::string &hex_str) { std::ostringstream oss; output_hex(first, last, oss); hex_str.assign(oss.str()); } -template -void bytes_to_hex_string(const InContainer& bytes, std::string& hex_str) { +template +void bytes_to_hex_string(const InContainer &bytes, std::string &hex_str) { bytes_to_hex_string(bytes.begin(), bytes.end(), hex_str); } -template +template std::string bytes_to_hex_string(InIter first, InIter last) { std::string hex_str; bytes_to_hex_string(first, last, hex_str); return hex_str; } -template -std::string bytes_to_hex_string(const InContainer& bytes) { +template +std::string bytes_to_hex_string(const InContainer &bytes) { std::string hex_str; bytes_to_hex_string(bytes, hex_str); return hex_str; } class hash256_one_by_one { - public: - hash256_one_by_one() { init(); } +public: + hash256_one_by_one() { + init(); + } void init() { buffer_.clear(); @@ -190,7 +206,7 @@ class hash256_one_by_one { detail::initial_message_digest + 8, h_); } - template + template void process(RaIter first, RaIter last) { add_to_data_length(static_cast(std::distance(first, last))); std::copy(first, last, std::back_inserter(buffer_)); @@ -221,9 +237,9 @@ class hash256_one_by_one { detail::hash256_block(h_, temp, temp + 64); } - template + template void get_hash_bytes(OutIter first, OutIter last) const { - for (const word_t* iter = h_; iter != h_ + 8; ++iter) { + for (const word_t *iter = h_; iter != h_ + 8; ++iter) { for (std::size_t i = 0; i < 4 && first != last; ++i) { *(first++) = detail::mask_8bit( static_cast((*iter >> (24 - 8 * i)))); @@ -231,7 +247,7 @@ class hash256_one_by_one { } } - private: +private: void add_to_data_length(word_t n) { word_t carry = 0; data_length_digits_[0] += n; @@ -245,7 +261,7 @@ class hash256_one_by_one { } } } - void write_data_bit_length(byte_t* begin) { + void write_data_bit_length(byte_t *begin) { word_t data_bit_length_digits[4]; std::copy(data_length_digits_, data_length_digits_ + 4, data_bit_length_digits); @@ -271,21 +287,21 @@ class hash256_one_by_one { word_t h_[8]; }; -inline void get_hash_hex_string(const hash256_one_by_one& hasher, - std::string& hex_str) { +inline void get_hash_hex_string(const hash256_one_by_one &hasher, + std::string &hex_str) { byte_t hash[k_digest_size]; hasher.get_hash_bytes(hash, hash + k_digest_size); return bytes_to_hex_string(hash, hash + k_digest_size, hex_str); } -inline std::string get_hash_hex_string(const hash256_one_by_one& hasher) { +inline std::string get_hash_hex_string(const hash256_one_by_one &hasher) { std::string hex_str; get_hash_hex_string(hasher, hex_str); return hex_str; } namespace impl { -template +template void hash256_impl(RaIter first, RaIter last, OutIter first2, OutIter last2, int, std::random_access_iterator_tag) { hash256_one_by_one hasher; @@ -295,7 +311,7 @@ void hash256_impl(RaIter first, RaIter last, OutIter first2, OutIter last2, int, hasher.get_hash_bytes(first2, last2); } -template +template void hash256_impl(InputIter first, InputIter last, OutIter first2, OutIter last2, int buffer_size, std::input_iterator_tag) { std::vector buffer(buffer_size); @@ -315,9 +331,9 @@ void hash256_impl(InputIter first, InputIter last, OutIter first2, hasher.finish(); hasher.get_hash_bytes(first2, last2); } -} +} // namespace impl -template +template void hash256(InIter first, InIter last, OutIter first2, OutIter last2, int buffer_size = PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR) { picosha2::impl::hash256_impl( @@ -325,23 +341,23 @@ void hash256(InIter first, InIter last, OutIter first2, OutIter last2, typename std::iterator_traits::iterator_category()); } -template -void hash256(InIter first, InIter last, OutContainer& dst) { +template +void hash256(InIter first, InIter last, OutContainer &dst) { hash256(first, last, dst.begin(), dst.end()); } -template -void hash256(const InContainer& src, OutIter first, OutIter last) { +template +void hash256(const InContainer &src, OutIter first, OutIter last) { hash256(src.begin(), src.end(), first, last); } -template -void hash256(const InContainer& src, OutContainer& dst) { +template +void hash256(const InContainer &src, OutContainer &dst) { hash256(src.begin(), src.end(), dst.begin(), dst.end()); } -template -void hash256_hex_string(InIter first, InIter last, std::string& hex_str) { +template +void hash256_hex_string(InIter first, InIter last, std::string &hex_str) { byte_t hashed[k_digest_size]; hash256(first, last, hashed, hashed + k_digest_size); std::ostringstream oss; @@ -349,29 +365,29 @@ void hash256_hex_string(InIter first, InIter last, std::string& hex_str) { hex_str.assign(oss.str()); } -template +template std::string hash256_hex_string(InIter first, InIter last) { std::string hex_str; hash256_hex_string(first, last, hex_str); return hex_str; } -inline void hash256_hex_string(const std::string& src, std::string& hex_str) { +inline void hash256_hex_string(const std::string &src, std::string &hex_str) { hash256_hex_string(src.begin(), src.end(), hex_str); } -template -void hash256_hex_string(const InContainer& src, std::string& hex_str) { +template +void hash256_hex_string(const InContainer &src, std::string &hex_str) { hash256_hex_string(src.begin(), src.end(), hex_str); } -template -std::string hash256_hex_string(const InContainer& src) { +template +std::string hash256_hex_string(const InContainer &src) { return hash256_hex_string(src.begin(), src.end()); } -templatevoid hash256(std::ifstream& f, OutIter first, OutIter last){ - hash256(std::istreambuf_iterator(f), std::istreambuf_iterator(), first,last); - +template +void hash256(std::ifstream &f, OutIter first, OutIter last) { + hash256(std::istreambuf_iterator(f), std::istreambuf_iterator(), first, last); } -}// namespace picosha2 +} // namespace picosha2 #endif // PICOSHA2_H diff --git a/src/bb/dnn/rt.h b/src/bb/dnn/rt.h index 3808a1a0..31540fe2 100644 --- a/src/bb/dnn/rt.h +++ b/src/bb/dnn/rt.h @@ -22,15 +22,15 @@ namespace dnn { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; -} // image_io -} // bb -} // ion +} // namespace dnn +} // namespace bb +} // namespace ion #define ION_REGISTER_EXTERN(NAME) static auto ion_register_extern_##NAME = ion::bb::dnn::RegisterExtern(#NAME, NAME); extern "C" ION_EXPORT int ion_bb_dnn_generic_object_detection(halide_buffer_t *in, @@ -273,7 +273,7 @@ extern "C" ION_EXPORT int ion_bb_dnn_classify_gender(halide_buffer_t *in_img, return 0; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return -1; } catch (...) { diff --git a/src/bb/dnn/rt_json.h b/src/bb/dnn/rt_json.h index 988f49d5..12cf36a9 100644 --- a/src/bb/dnn/rt_json.h +++ b/src/bb/dnn/rt_json.h @@ -16,159 +16,160 @@ namespace dnn { namespace json { class DictAverageRegurator { - public: - static DictAverageRegurator& get_instance(const std::string& session_id, uint32_t period_in_sec) { - static std::unordered_map> instances; - if (instances.count(session_id) == 0) { - instances[session_id] = std::unique_ptr(new DictAverageRegurator(period_in_sec)); - } - return *instances[session_id].get(); - } - - nlohmann::json process(nlohmann::json in) { - using js = nlohmann::json; - - if (!in.is_object()) { - throw std::runtime_error("Unexpected data format: input is not an object"); - } - - for (js::iterator it = in.begin(); it != in.end(); ++it) { - if (!it.value().is_number()) { - throw std::runtime_error("Unexpected format: value is not a number"); - } - - if (data_.count(it.key()) == 0) { - data_[it.key()] = 0.0f; - } - - data_[it.key()] = data_[it.key()] + static_cast(it.value()); - } - count_++; - - auto now = std::chrono::system_clock::now(); - if (std::chrono::duration_cast(now - tp_).count() >= period_in_sec_) { - js j; - for (auto& d : data_) { - std::stringstream ss; - ss << std::fixed << std::setprecision(2) << d.second / static_cast(count_); - j[d.first] = ss.str(); - } - data_.clear(); - tp_ = now; - count_ = 0; - return j; - } else { - return js(); - } - - return in; - } - - private: - DictAverageRegurator(uint32_t period_in_sec) : period_in_sec_(period_in_sec), tp_(std::chrono::system_clock::now()), count_(0) {} - uint32_t period_in_sec_; - std::unordered_map data_; - std::chrono::time_point tp_; - uint32_t count_; +public: + static DictAverageRegurator &get_instance(const std::string &session_id, uint32_t period_in_sec) { + static std::unordered_map> instances; + if (instances.count(session_id) == 0) { + instances[session_id] = std::unique_ptr(new DictAverageRegurator(period_in_sec)); + } + return *instances[session_id].get(); + } + + nlohmann::json process(nlohmann::json in) { + using js = nlohmann::json; + + if (!in.is_object()) { + throw std::runtime_error("Unexpected data format: input is not an object"); + } + + for (js::iterator it = in.begin(); it != in.end(); ++it) { + if (!it.value().is_number()) { + throw std::runtime_error("Unexpected format: value is not a number"); + } + + if (data_.count(it.key()) == 0) { + data_[it.key()] = 0.0f; + } + + data_[it.key()] = data_[it.key()] + static_cast(it.value()); + } + count_++; + + auto now = std::chrono::system_clock::now(); + if (std::chrono::duration_cast(now - tp_).count() >= period_in_sec_) { + js j; + for (auto &d : data_) { + std::stringstream ss; + ss << std::fixed << std::setprecision(2) << d.second / static_cast(count_); + j[d.first] = ss.str(); + } + data_.clear(); + tp_ = now; + count_ = 0; + return j; + } else { + return js(); + } + + return in; + } + +private: + DictAverageRegurator(uint32_t period_in_sec) + : period_in_sec_(period_in_sec), tp_(std::chrono::system_clock::now()), count_(0) { + } + uint32_t period_in_sec_; + std::unordered_map data_; + std::chrono::time_point tp_; + uint32_t count_; }; class WebHookUploader { - public: - static WebHookUploader& get_instance(const std::string& session_id, const std::string& url) { - static std::unordered_map> instances; - if (instances.count(session_id) == 0) { - instances[session_id] = std::unique_ptr(new WebHookUploader(url)); - } - return *instances[session_id].get(); - } - - void upload(nlohmann::json in) { - using js = nlohmann::json; - if (in.is_null()) { - return; - } - - js j; - j["value1"] = in.dump(); - - std::unique_lock lock(mutex_); - if (ep_) { - std::rethrow_exception(ep_); - } - - queue_.push(j.dump()); - cv_.notify_one(); - } - - ~WebHookUploader() { - if (thread_->joinable()) { - keep_running_ = false; - cv_.notify_one(); - thread_->join(); - } - } - - private: - WebHookUploader(const std::string& url) +public: + static WebHookUploader &get_instance(const std::string &session_id, const std::string &url) { + static std::unordered_map> instances; + if (instances.count(session_id) == 0) { + instances[session_id] = std::unique_ptr(new WebHookUploader(url)); + } + return *instances[session_id].get(); + } + + void upload(nlohmann::json in) { + using js = nlohmann::json; + if (in.is_null()) { + return; + } + + js j; + j["value1"] = in.dump(); + + std::unique_lock lock(mutex_); + if (ep_) { + std::rethrow_exception(ep_); + } + + queue_.push(j.dump()); + cv_.notify_one(); + } + + ~WebHookUploader() { + if (thread_->joinable()) { + keep_running_ = false; + cv_.notify_one(); + thread_->join(); + } + } + +private: + WebHookUploader(const std::string &url) : keep_running_(true) { - std::string host_name; - std::tie(host_name, path_name_) = parse_url(url); - if (host_name.empty() || path_name_.empty()) { - throw std::runtime_error("Invalid URL : " + url); - } - - cli_ = std::unique_ptr(new httplib::Client(host_name.c_str())); - if (!cli_->is_valid()) { - throw std::runtime_error("Failed to create HTTP client : " + url); - } - - thread_ = std::unique_ptr(new std::thread(entry_point, this)); - }; - - static void entry_point(WebHookUploader* obj) { - try { - obj->thread_main(); - } - catch (...) { - std::unique_lock lock(obj->mutex_); - obj->ep_ = std::current_exception(); - } - } - - void thread_main() { - while (true) { - std::string body; - { - std::unique_lock lock(mutex_); - cv_.wait(lock, [&] { return !queue_.empty() || !keep_running_; }); - if (!keep_running_) { - break; - } - body = queue_.front(); - queue_.pop(); - } - - auto res = cli_->Post(path_name_.c_str(), body, "application/json"); - if (!res || res->status != 200) { - throw std::runtime_error("Failed to upload data"); - } - } - } - - std::unique_ptr cli_; - std::string path_name_; - - std::unique_ptr thread_; - std::queue queue_; - std::mutex mutex_; - std::condition_variable cv_; - bool keep_running_; - std::exception_ptr ep_; + std::string host_name; + std::tie(host_name, path_name_) = parse_url(url); + if (host_name.empty() || path_name_.empty()) { + throw std::runtime_error("Invalid URL : " + url); + } + + cli_ = std::unique_ptr(new httplib::Client(host_name.c_str())); + if (!cli_->is_valid()) { + throw std::runtime_error("Failed to create HTTP client : " + url); + } + + thread_ = std::unique_ptr(new std::thread(entry_point, this)); + }; + + static void entry_point(WebHookUploader *obj) { + try { + obj->thread_main(); + } catch (...) { + std::unique_lock lock(obj->mutex_); + obj->ep_ = std::current_exception(); + } + } + + void thread_main() { + while (true) { + std::string body; + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return !queue_.empty() || !keep_running_; }); + if (!keep_running_) { + break; + } + body = queue_.front(); + queue_.pop(); + } + + auto res = cli_->Post(path_name_.c_str(), body, "application/json"); + if (!res || res->status != 200) { + throw std::runtime_error("Failed to upload data"); + } + } + } + + std::unique_ptr cli_; + std::string path_name_; + + std::unique_ptr thread_; + std::queue queue_; + std::mutex mutex_; + std::condition_variable cv_; + bool keep_running_; + std::exception_ptr ep_; }; -} // json -} // dnn -} // bb -} // ion +} // namespace json +} // namespace dnn +} // namespace bb +} // namespace ion #endif diff --git a/src/bb/dnn/rt_opencv.h b/src/bb/dnn/rt_opencv.h index 7d2e2dfb..88d7459c 100644 --- a/src/bb/dnn/rt_opencv.h +++ b/src/bb/dnn/rt_opencv.h @@ -22,124 +22,122 @@ using json = nlohmann::json; using ClassifyResult = std::unordered_map; class Classifier { - public: - static Classifier& get_instance(const std::string& uuid, const std::string& model_root_url, const std::string& cache_root) { - static std::map> map_; - Classifier *c; - if (map_.count(uuid) == 0) { - map_[uuid] = std::unique_ptr(new Classifier(model_root_url, cache_root)); - } - return *map_[uuid].get(); - } - - ClassifyResult classify( - const cv::Mat& image, - const std::vector& boxes) { - ClassifyResult result; - - const int PeopleNetClassID_Face = 2; - const cv::Scalar MODEL_MEAN_VALUES = cv::Scalar(78.4263377603, 87.7689143744, 114.895847746); - - result["Male"] = 0; - result["Female"] = 0; - - for (auto b: boxes) { - if (b.class_id == PeopleNetClassID_Face) { - if (b.x2-b.x1 < 100 || b.y2-b.y1 < 100) { - continue; - } - cv::Mat face(image, cv::Rect(b.x1, b.y1, b.x2-b.x1, b.y2-b.y1)); - cv::normalize(face, face, 0, 255, cv::NORM_MINMAX, CV_8UC3); - cv::cvtColor(face, face, cv::COLOR_RGB2BGR); - - cv::Mat blob = cv::dnn::blobFromImage(face, 1, cv::Size(227, 227), MODEL_MEAN_VALUES, false); - net_.setInput(blob); - // // string gender_preds; - std::vector genderPreds = net_.forward(); - // // printing gender here - // // find max element index - // // distance function does the argmax() work in C++ - const char *genderList[] = {"Male", "Female"}; - int max_index_gender = std::distance(genderPreds.begin(), max_element(genderPreds.begin(), genderPreds.end())); - std::string gender = genderList[max_index_gender]; - result[gender]++; - } - } - return result; - } - - private: - - std::string cache_load(const std::string& model_root_url, const std::string& file_name, const std::string& cache_root) { - const std::string url = model_root_url + file_name; - - std::vector hash(picosha2::k_digest_size); - picosha2::hash256(url.begin(), url.end(), hash.begin(), hash.end()); - auto hash_str = picosha2::bytes_to_hex_string(hash.begin(), hash.end()); - - auto path = cache_root + file_name + "." + hash_str; - - std::ifstream ifs(path, std::ios::binary); - if (ifs.is_open()) { - return path; - } - ifs.close(); - - std::string host_name; - std::string path_name; - std::tie(host_name, path_name) = parse_url(url); - if (host_name.empty() || path_name.empty()) { - throw std::runtime_error("Invalid URL : " + url); - } - - httplib::Client cli(host_name.c_str()); - cli.set_follow_location(true); - auto res = cli.Get(path_name.c_str()); - if (!res || res->status != 200) { - throw std::runtime_error("Failed to download file: " + url); - } - - std::ofstream ofs(path, std::ios::binary); - ofs.write(res->body.c_str(), res->body.size()); - - return path; - } - - Classifier(const std::string& model_root_url, const std::string& cache_root) - { - auto model_define = cache_load(model_root_url, "model_define.prototxt", cache_root); - auto model_weight = cache_load(model_root_url, "model_weight.caffemodel", cache_root); - net_ = cv::dnn::readNet(model_weight, model_define, "caffe"); - } - - cv::dnn::Net net_; +public: + static Classifier &get_instance(const std::string &uuid, const std::string &model_root_url, const std::string &cache_root) { + static std::map> map_; + Classifier *c; + if (map_.count(uuid) == 0) { + map_[uuid] = std::unique_ptr(new Classifier(model_root_url, cache_root)); + } + return *map_[uuid].get(); + } + + ClassifyResult classify( + const cv::Mat &image, + const std::vector &boxes) { + ClassifyResult result; + + const int PeopleNetClassID_Face = 2; + const cv::Scalar MODEL_MEAN_VALUES = cv::Scalar(78.4263377603, 87.7689143744, 114.895847746); + + result["Male"] = 0; + result["Female"] = 0; + + for (auto b : boxes) { + if (b.class_id == PeopleNetClassID_Face) { + if (b.x2 - b.x1 < 100 || b.y2 - b.y1 < 100) { + continue; + } + cv::Mat face(image, cv::Rect(b.x1, b.y1, b.x2 - b.x1, b.y2 - b.y1)); + cv::normalize(face, face, 0, 255, cv::NORM_MINMAX, CV_8UC3); + cv::cvtColor(face, face, cv::COLOR_RGB2BGR); + + cv::Mat blob = cv::dnn::blobFromImage(face, 1, cv::Size(227, 227), MODEL_MEAN_VALUES, false); + net_.setInput(blob); + // // string gender_preds; + std::vector genderPreds = net_.forward(); + // // printing gender here + // // find max element index + // // distance function does the argmax() work in C++ + const char *genderList[] = {"Male", "Female"}; + int max_index_gender = std::distance(genderPreds.begin(), max_element(genderPreds.begin(), genderPreds.end())); + std::string gender = genderList[max_index_gender]; + result[gender]++; + } + } + return result; + } + +private: + std::string cache_load(const std::string &model_root_url, const std::string &file_name, const std::string &cache_root) { + const std::string url = model_root_url + file_name; + + std::vector hash(picosha2::k_digest_size); + picosha2::hash256(url.begin(), url.end(), hash.begin(), hash.end()); + auto hash_str = picosha2::bytes_to_hex_string(hash.begin(), hash.end()); + + auto path = cache_root + file_name + "." + hash_str; + + std::ifstream ifs(path, std::ios::binary); + if (ifs.is_open()) { + return path; + } + ifs.close(); + + std::string host_name; + std::string path_name; + std::tie(host_name, path_name) = parse_url(url); + if (host_name.empty() || path_name.empty()) { + throw std::runtime_error("Invalid URL : " + url); + } + + httplib::Client cli(host_name.c_str()); + cli.set_follow_location(true); + auto res = cli.Get(path_name.c_str()); + if (!res || res->status != 200) { + throw std::runtime_error("Failed to download file: " + url); + } + + std::ofstream ofs(path, std::ios::binary); + ofs.write(res->body.c_str(), res->body.size()); + + return path; + } + + Classifier(const std::string &model_root_url, const std::string &cache_root) { + auto model_define = cache_load(model_root_url, "model_define.prototxt", cache_root); + auto model_weight = cache_load(model_root_url, "model_weight.caffemodel", cache_root); + net_ = cv::dnn::readNet(model_weight, model_define, "caffe"); + } + + cv::dnn::Net net_; }; void classify_gender(halide_buffer_t *in_img, halide_buffer_t *in_md, int32_t output_size, - const std::string& session_id, - const std::string& model_root_url, - const std::string& cache_root, + const std::string &session_id, + const std::string &model_root_url, + const std::string &cache_root, halide_buffer_t *out) { using namespace cv; using json = nlohmann::json; - auto& classifier = Classifier::get_instance(session_id, model_root_url, cache_root); + auto &classifier = Classifier::get_instance(session_id, model_root_url, cache_root); const int width = in_img->dim[1].extent; const int height = in_img->dim[2].extent; cv::Mat image(height, width, CV_32FC3, in_img->host); - auto boxes = json::parse(reinterpret_cast(in_md->host)).get>(); + auto boxes = json::parse(reinterpret_cast(in_md->host)).get>(); ClassifyResult classify_result = classifier.classify(image, boxes); json j = classify_result; std::string output_string(j.dump()); - if (output_string.size()+1 >= output_size) { + if (output_string.size() + 1 >= output_size) { throw std::runtime_error("Output buffer size is not sufficient"); } @@ -149,10 +147,9 @@ void classify_gender(halide_buffer_t *in_img, return; } +} // namespace opencv +} // namespace dnn +} // namespace bb +} // namespace ion -} // cv -} // dnn -} // bb -} // ion - -#endif // ION_BB_DNN_RT_OPENCV_H +#endif // ION_BB_DNN_RT_OPENCV_H diff --git a/src/bb/dnn/rt_ort.h b/src/bb/dnn/rt_ort.h index cca45c7e..fdec741b 100644 --- a/src/bb/dnn/rt_ort.h +++ b/src/bb/dnn/rt_ort.h @@ -23,7 +23,7 @@ namespace dnn { class OrtSessionManager { public: - OrtSessionManager(const std::string& model_root_url, const std ::string &cache_root, bool cuda_enable) + OrtSessionManager(const std::string &model_root_url, const std ::string &cache_root, bool cuda_enable) : ort_{new ONNXRuntime()} { const OrtApi *api = ort_->get_api(); ort_->check_status(api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "ion_bb_dnn_ort", &env)); @@ -39,7 +39,7 @@ class OrtSessionManager { // ort_->enable_tensorrt_provider(session_options, 0); // } - //std::string model_url = model_root_url + "yolov4-tiny_416_416.onnx"; + // std::string model_url = model_root_url + "yolov4-tiny_416_416.onnx"; std::string model_name = "ssd_mobilenet_v2_coco_2018_03_29.onnx"; std::ifstream ifs(cache_root + model_name, std::ios::binary); @@ -48,8 +48,8 @@ class OrtSessionManager { ifs.seekg(0, std::ios::end); auto end = ifs.tellg(); ifs.seekg(0, std::ios::beg); - model_.resize(end-begin); - ifs.read(reinterpret_cast(model_.data()), model_.size()); + model_.resize(end - begin); + ifs.read(reinterpret_cast(model_.data()), model_.size()); } else { std::string model_url = model_root_url + model_name; std::string host_name; @@ -71,8 +71,8 @@ class OrtSessionManager { model_.resize(res->body.size()); std::memcpy(model_.data(), res->body.c_str(), res->body.size()); - std::ofstream ofs (cache_root + model_name, std::ios::binary); - ofs.write(reinterpret_cast(model_.data()), model_.size()); + std::ofstream ofs(cache_root + model_name, std::ios::binary); + ofs.write(reinterpret_cast(model_.data()), model_.size()); } ort_->check_status(api->CreateSessionFromArray(env, model_.data(), model_.size(), session_options, &session)); @@ -90,7 +90,7 @@ class OrtSessionManager { return session; } - static OrtSessionManager *make(const std::string &uuid, const std::string& model_root_url, const std ::string &cache_root, bool cuda_enable) { + static OrtSessionManager *make(const std::string &uuid, const std::string &model_root_url, const std ::string &cache_root, bool cuda_enable) { static std::map> map_; OrtSessionManager *ort_manager; if (map_.count(uuid) == 0) { @@ -136,9 +136,9 @@ bool is_ort_available() { } int object_detection_ort(halide_buffer_t *in, - const std::string& session_id, - const std::string& model_root_url, - const std::string& cache_root, + const std::string &session_id, + const std::string &model_root_url, + const std::string &cache_root, bool cuda_enable, halide_buffer_t *out) { @@ -184,7 +184,7 @@ int object_detection_ort(halide_buffer_t *in, int num_images = in->dimensions == 3 ? 1 : in->dim[3].extent; - for (int i=0; ihost + offset); @@ -202,7 +202,7 @@ int object_detection_ort(halide_buffer_t *in, resized.convertTo(input_tensor_data, CV_8UC3, 255.0); - uint8_t *input_tensor_ptr = reinterpret_cast(input_tensor_data.ptr()); + uint8_t *input_tensor_ptr = reinterpret_cast(input_tensor_data.ptr()); OrtMemoryInfo *memory_info; ort->check_status(api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); @@ -213,7 +213,7 @@ int object_detection_ort(halide_buffer_t *in, assert(is_tensor); api->ReleaseMemoryInfo(memory_info); - //std::vector output_tensor_names = {"boxes", "confs"}; + // std::vector output_tensor_names = {"boxes", "confs"}; std::vector output_tensor_names = {"detection_boxes:0", "detection_classes:0", "detection_scores:0", "num_detections:0"}; std::vector output_tensors(4); ort->check_status(api->Run(session, NULL, &input_name, (const OrtValue *const *)&input_tensor, 1, output_tensor_names.data(), 4, output_tensors.data())); @@ -247,7 +247,7 @@ int object_detection_ort(halide_buffer_t *in, // const int num = 2535; // const int num_classes = 80; - //const auto prediceted_boxes = yolo_post_processing(boxes_ptr, confs_ptr, num, num_classes); + // const auto prediceted_boxes = yolo_post_processing(boxes_ptr, confs_ptr, num, num_classes); const auto prediceted_boxes = ssd_post_processing(boxes_ptr, classes_ptr, scores_ptr, static_cast(lround(*nums_ptr))); cv::Mat out_(height, width, CV_32FC3, out->host + offset); in_.copyTo(out_); @@ -268,4 +268,4 @@ int object_detection_ort(halide_buffer_t *in, } // namespace bb } // namespace ion -#endif // ION_BB_DNN_RT_ORT_H +#endif // ION_BB_DNN_RT_ORT_H diff --git a/src/bb/dnn/rt_tfl.h b/src/bb/dnn/rt_tfl.h index 8b30ed9e..47f4f388 100644 --- a/src/bb/dnn/rt_tfl.h +++ b/src/bb/dnn/rt_tfl.h @@ -17,23 +17,23 @@ namespace bb { namespace dnn { class TflSessionManager { - public: - static TflSessionManager& get_instance() { - static TflSessionManager instance; - return instance; - } - - struct TfLiteObjects { - // Need to hold model data permanently - std::shared_ptr> model_data; - std::shared_ptr model; - std::shared_ptr delegate; - std::shared_ptr options; - std::shared_ptr interpreter; - }; - - std::shared_ptr get_interpreter(const std::string& model_root_url, const std::string& cache_root) { - std::string model_name; +public: + static TflSessionManager &get_instance() { + static TflSessionManager instance; + return instance; + } + + struct TfLiteObjects { + // Need to hold model data permanently + std::shared_ptr> model_data; + std::shared_ptr model; + std::shared_ptr delegate; + std::shared_ptr options; + std::shared_ptr interpreter; + }; + + std::shared_ptr get_interpreter(const std::string &model_root_url, const std::string &cache_root) { + std::string model_name; if (is_available_edgetpu_) { model_name = "ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite"; } else { @@ -52,9 +52,9 @@ class TflSessionManager { ifs.seekg(0, std::ios::end); auto end = ifs.tellg(); ifs.seekg(0, std::ios::beg); - model_data = std::shared_ptr>(new std::vector(end-begin)); + model_data = std::shared_ptr>(new std::vector(end - begin)); ifs.read(reinterpret_cast(model_data->data()), model_data->size()); - } else { + } else { std::string host_name; std::string path_name; std::tie(host_name, path_name) = parse_url(model_url); @@ -74,8 +74,8 @@ class TflSessionManager { model_data = std::shared_ptr>(new std::vector(res->body.size())); std::memcpy(model_data->data(), res->body.c_str(), res->body.size()); - std::ofstream ofs (cache_root + model_name, std::ios::binary); - ofs.write(reinterpret_cast(model_data->data()), model_data->size()); + std::ofstream ofs(cache_root + model_name, std::ios::binary); + ofs.write(reinterpret_cast(model_data->data()), model_data->size()); } std::shared_ptr model(TfLiteModelCreate(model_data->data(), model_data->size()), TfLiteModelDelete); @@ -96,7 +96,7 @@ class TflSessionManager { std::cerr << "No device found" << std::endl; return nullptr; } - const auto& device = devices.get()[0]; + const auto &device = devices.get()[0]; // Create EdgeTpu delegate delegate = std::shared_ptr(edgetpu_create_delegate(device.type, device.path, nullptr, 0), edgetpu_free_delegate); @@ -112,25 +112,23 @@ class TflSessionManager { return nullptr; } - if (TfLiteInterpreterAllocateTensors(interpreter.get())!= kTfLiteOk) { + if (TfLiteInterpreterAllocateTensors(interpreter.get()) != kTfLiteOk) { std::cerr << "Failed to allocate tensors." << std::endl; return nullptr; } objects_[model_url] = TfLiteObjects{ - model_data, model, delegate, options, interpreter - }; + model_data, model, delegate, options, interpreter}; return interpreter; - } + } - bool is_available() { - return is_available_tflite_; - } + bool is_available() { + return is_available_tflite_; + } - private: +private: TflSessionManager() - : is_available_tflite_(false), is_available_edgetpu_(false) - { + : is_available_tflite_(false), is_available_edgetpu_(false) { if (!tensorflowlite_init()) { return; } @@ -153,8 +151,8 @@ bool is_tfl_available() { } int object_detection_tfl(halide_buffer_t *in, - const std::string& model_root_url, - const std::string& cache_root, + const std::string &model_root_url, + const std::string &cache_root, halide_buffer_t *out) { const int channel = 3; @@ -165,7 +163,7 @@ int object_detection_tfl(halide_buffer_t *in, int num_images = in->dimensions == 3 ? 1 : in->dim[3].extent; - for (int i=0; ihost + offset); @@ -176,7 +174,7 @@ int object_detection_tfl(halide_buffer_t *in, if (channel != TfLiteTensorDim(input, 3)) { std::cerr << "Input channel mismatches: " - << channel << " vs " << TfLiteTensorDim(input, 3) << std::endl; + << channel << " vs " << TfLiteTensorDim(input, 3) << std::endl; return -1; } @@ -190,10 +188,10 @@ int object_detection_tfl(halide_buffer_t *in, resized.convertTo(input_tensor_data, CV_8UC3, 255.0); - if ((3*input_tensor_data.total()) != TfLiteTensorByteSize(input)) { + if ((3 * input_tensor_data.total()) != TfLiteTensorByteSize(input)) { std::cerr << "Input size mismatches: " - << 3*input_tensor_data.total() << " vs " << TfLiteTensorByteSize(input) - << std::endl; + << 3 * input_tensor_data.total() << " vs " << TfLiteTensorByteSize(input) + << std::endl; return -1; } @@ -212,15 +210,15 @@ int object_detection_tfl(halide_buffer_t *in, return -1; } - const TfLiteTensor* boxes = TfLiteInterpreterGetOutputTensor(interpreter.get(), 0); - const TfLiteTensor* classes = TfLiteInterpreterGetOutputTensor(interpreter.get(), 1); - const TfLiteTensor* scores = TfLiteInterpreterGetOutputTensor(interpreter.get(), 2); - const TfLiteTensor* num = TfLiteInterpreterGetOutputTensor(interpreter.get(), 3); + const TfLiteTensor *boxes = TfLiteInterpreterGetOutputTensor(interpreter.get(), 0); + const TfLiteTensor *classes = TfLiteInterpreterGetOutputTensor(interpreter.get(), 1); + const TfLiteTensor *scores = TfLiteInterpreterGetOutputTensor(interpreter.get(), 2); + const TfLiteTensor *num = TfLiteInterpreterGetOutputTensor(interpreter.get(), 3); - float *boxes_ptr = reinterpret_cast(TfLiteTensorData(boxes)); - float *classes_ptr = reinterpret_cast(TfLiteTensorData(classes)); - float *scores_ptr = reinterpret_cast(TfLiteTensorData(scores)); - float *num_ptr = reinterpret_cast(TfLiteTensorData(num)); + float *boxes_ptr = reinterpret_cast(TfLiteTensorData(boxes)); + float *classes_ptr = reinterpret_cast(TfLiteTensorData(classes)); + float *scores_ptr = reinterpret_cast(TfLiteTensorData(scores)); + float *num_ptr = reinterpret_cast(TfLiteTensorData(num)); const auto detected_boxes = ssd_post_processing(boxes_ptr, classes_ptr, scores_ptr, static_cast(*num_ptr)); @@ -238,4 +236,4 @@ int object_detection_tfl(halide_buffer_t *in, } // namespace bb } // namespace ion -#endif // ION_BB_DNN_RT_TFL_H +#endif // ION_BB_DNN_RT_TFL_H diff --git a/src/bb/dnn/rt_util.h b/src/bb/dnn/rt_util.h index ad5d2889..f0beeb2f 100644 --- a/src/bb/dnn/rt_util.h +++ b/src/bb/dnn/rt_util.h @@ -13,7 +13,6 @@ #include #endif - #include namespace ion { diff --git a/src/bb/dnn/rt_yolo.h b/src/bb/dnn/rt_yolo.h index 94d71b21..3b68e137 100644 --- a/src/bb/dnn/rt_yolo.h +++ b/src/bb/dnn/rt_yolo.h @@ -63,8 +63,8 @@ std::vector yolo_post_processing(const float *boxes, const float * return detected_boxes; } -} // dnn -} // bb -} // ion +} // namespace dnn +} // namespace bb +} // namespace ion #endif diff --git a/src/bb/dnn/tensorflowlite_c.h b/src/bb/dnn/tensorflowlite_c.h index e80350c4..1e5a3685 100644 --- a/src/bb/dnn/tensorflowlite_c.h +++ b/src/bb/dnn/tensorflowlite_c.h @@ -15,59 +15,59 @@ typedef struct TfLiteModel TfLiteModel; typedef struct TfLiteInterpreterOptions TfLiteInterpreterOptions; typedef struct TfLiteInterpreter TfLiteInterpreter; -using TfLiteModelCreate_t = TfLiteModel* (*)(const void* model_data, size_t model_size); -using TfLiteModelCreateFromFile_t = TfLiteModel* (*)(const char* model_path); -using TfLiteModelDelete_t = void (*)(TfLiteModel* model); -using TfLiteInterpreterOptionsCreate_t = TfLiteInterpreterOptions* (*)(); -using TfLiteInterpreterOptionsDelete_t = void (*)( TfLiteInterpreterOptions* options); -using TfLiteInterpreterOptionsSetNumThreads_t = void (*)(TfLiteInterpreterOptions* options, int32_t num_threads); -using TfLiteInterpreterOptionsAddDelegate_t = void (*)(TfLiteInterpreterOptions* options, TfLiteDelegate* delegate); -using TfLiteInterpreterOptionsSetErrorReporter_t = void (*)( TfLiteInterpreterOptions* options, void (*reporter)(void* user_data, const char* format, va_list args), void* user_data); -using TfLiteInterpreterCreate_t = TfLiteInterpreter* (*)(const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options); -using TfLiteInterpreterDelete_t = void (*)(TfLiteInterpreter* interpreter); -using TfLiteInterpreterGetInputTensorCount_t = int32_t (*)(const TfLiteInterpreter* interpreter); -using TfLiteInterpreterGetInputTensor_t = TfLiteTensor* (*)(const TfLiteInterpreter* interpreter, int32_t input_index); -using TfLiteInterpreterResizeInputTensor_t = TfLiteStatus (*)(TfLiteInterpreter* interpreter, int32_t input_index, const int* input_dims, int32_t input_dims_size); -using TfLiteInterpreterAllocateTensors_t = TfLiteStatus (*)( TfLiteInterpreter* interpreter); -using TfLiteInterpreterInvoke_t = TfLiteStatus (*)( TfLiteInterpreter* interpreter); -using TfLiteInterpreterGetOutputTensorCount_t = int32_t (*)( const TfLiteInterpreter* interpreter); -using TfLiteInterpreterGetOutputTensor_t = const TfLiteTensor* (*)( const TfLiteInterpreter* interpreter, int32_t output_index); -using TfLiteTensorType_t = TfLiteType (*)(const TfLiteTensor* tensor); -using TfLiteTensorNumDims_t = int32_t (*)(const TfLiteTensor* tensor); -using TfLiteTensorDim_t = int32_t (*)(const TfLiteTensor* tensor, int32_t dim_index); -using TfLiteTensorByteSize_t = size_t (*)(const TfLiteTensor* tensor); -using TfLiteTensorData_t = void* (*)(const TfLiteTensor* tensor); -using TfLiteTensorName_t = const char* (*)(const TfLiteTensor* tensor); -using TfLiteTensorQuantizationParams_t = TfLiteQuantizationParams (*)(const TfLiteTensor* tensor); -using TfLiteTensorCopyFromBuffer_t = TfLiteStatus (*)(TfLiteTensor* tensor, const void* input_data, size_t input_data_size); -using TfLiteTensorCopyToBuffer_t = TfLiteStatus (*)(const TfLiteTensor* output_tensor, void* output_data, size_t output_data_size); +using TfLiteModelCreate_t = TfLiteModel *(*)(const void *model_data, size_t model_size); +using TfLiteModelCreateFromFile_t = TfLiteModel *(*)(const char *model_path); +using TfLiteModelDelete_t = void (*)(TfLiteModel *model); +using TfLiteInterpreterOptionsCreate_t = TfLiteInterpreterOptions *(*)(); +using TfLiteInterpreterOptionsDelete_t = void (*)(TfLiteInterpreterOptions *options); +using TfLiteInterpreterOptionsSetNumThreads_t = void (*)(TfLiteInterpreterOptions *options, int32_t num_threads); +using TfLiteInterpreterOptionsAddDelegate_t = void (*)(TfLiteInterpreterOptions *options, TfLiteDelegate *delegate); +using TfLiteInterpreterOptionsSetErrorReporter_t = void (*)(TfLiteInterpreterOptions *options, void (*reporter)(void *user_data, const char *format, va_list args), void *user_data); +using TfLiteInterpreterCreate_t = TfLiteInterpreter *(*)(const TfLiteModel *model, const TfLiteInterpreterOptions *optional_options); +using TfLiteInterpreterDelete_t = void (*)(TfLiteInterpreter *interpreter); +using TfLiteInterpreterGetInputTensorCount_t = int32_t (*)(const TfLiteInterpreter *interpreter); +using TfLiteInterpreterGetInputTensor_t = TfLiteTensor *(*)(const TfLiteInterpreter *interpreter, int32_t input_index); +using TfLiteInterpreterResizeInputTensor_t = TfLiteStatus (*)(TfLiteInterpreter *interpreter, int32_t input_index, const int *input_dims, int32_t input_dims_size); +using TfLiteInterpreterAllocateTensors_t = TfLiteStatus (*)(TfLiteInterpreter *interpreter); +using TfLiteInterpreterInvoke_t = TfLiteStatus (*)(TfLiteInterpreter *interpreter); +using TfLiteInterpreterGetOutputTensorCount_t = int32_t (*)(const TfLiteInterpreter *interpreter); +using TfLiteInterpreterGetOutputTensor_t = const TfLiteTensor *(*)(const TfLiteInterpreter *interpreter, int32_t output_index); +using TfLiteTensorType_t = TfLiteType (*)(const TfLiteTensor *tensor); +using TfLiteTensorNumDims_t = int32_t (*)(const TfLiteTensor *tensor); +using TfLiteTensorDim_t = int32_t (*)(const TfLiteTensor *tensor, int32_t dim_index); +using TfLiteTensorByteSize_t = size_t (*)(const TfLiteTensor *tensor); +using TfLiteTensorData_t = void *(*)(const TfLiteTensor *tensor); +using TfLiteTensorName_t = const char *(*)(const TfLiteTensor *tensor); +using TfLiteTensorQuantizationParams_t = TfLiteQuantizationParams (*)(const TfLiteTensor *tensor); +using TfLiteTensorCopyFromBuffer_t = TfLiteStatus (*)(TfLiteTensor *tensor, const void *input_data, size_t input_data_size); +using TfLiteTensorCopyToBuffer_t = TfLiteStatus (*)(const TfLiteTensor *output_tensor, void *output_data, size_t output_data_size); -TfLiteModelCreate_t TfLiteModelCreate; -TfLiteModelCreateFromFile_t TfLiteModelCreateFromFile; -TfLiteModelDelete_t TfLiteModelDelete; -TfLiteInterpreterOptionsCreate_t TfLiteInterpreterOptionsCreate; -TfLiteInterpreterOptionsDelete_t TfLiteInterpreterOptionsDelete; -TfLiteInterpreterOptionsSetNumThreads_t TfLiteInterpreterOptionsSetNumThreads; -TfLiteInterpreterOptionsAddDelegate_t TfLiteInterpreterOptionsAddDelegate; +TfLiteModelCreate_t TfLiteModelCreate; +TfLiteModelCreateFromFile_t TfLiteModelCreateFromFile; +TfLiteModelDelete_t TfLiteModelDelete; +TfLiteInterpreterOptionsCreate_t TfLiteInterpreterOptionsCreate; +TfLiteInterpreterOptionsDelete_t TfLiteInterpreterOptionsDelete; +TfLiteInterpreterOptionsSetNumThreads_t TfLiteInterpreterOptionsSetNumThreads; +TfLiteInterpreterOptionsAddDelegate_t TfLiteInterpreterOptionsAddDelegate; TfLiteInterpreterOptionsSetErrorReporter_t TfLiteInterpreterOptionsSetErrorReporter; -TfLiteInterpreterCreate_t TfLiteInterpreterCreate; -TfLiteInterpreterDelete_t TfLiteInterpreterDelete; -TfLiteInterpreterGetInputTensorCount_t TfLiteInterpreterGetInputTensorCount; -TfLiteInterpreterGetInputTensor_t TfLiteInterpreterGetInputTensor; -TfLiteInterpreterResizeInputTensor_t TfLiteInterpreterResizeInputTensor; -TfLiteInterpreterAllocateTensors_t TfLiteInterpreterAllocateTensors; -TfLiteInterpreterInvoke_t TfLiteInterpreterInvoke; -TfLiteInterpreterGetOutputTensorCount_t TfLiteInterpreterGetOutputTensorCount; -TfLiteInterpreterGetOutputTensor_t TfLiteInterpreterGetOutputTensor; -TfLiteTensorType_t TfLiteTensorType; -TfLiteTensorNumDims_t TfLiteTensorNumDims; -TfLiteTensorDim_t TfLiteTensorDim; -TfLiteTensorByteSize_t TfLiteTensorByteSize; -TfLiteTensorData_t TfLiteTensorData; -TfLiteTensorName_t TfLiteTensorName; -TfLiteTensorQuantizationParams_t TfLiteTensorQuantizationParams; -TfLiteTensorCopyFromBuffer_t TfLiteTensorCopyFromBuffer; -TfLiteTensorCopyToBuffer_t TfLiteTensorCopyToBuffer; +TfLiteInterpreterCreate_t TfLiteInterpreterCreate; +TfLiteInterpreterDelete_t TfLiteInterpreterDelete; +TfLiteInterpreterGetInputTensorCount_t TfLiteInterpreterGetInputTensorCount; +TfLiteInterpreterGetInputTensor_t TfLiteInterpreterGetInputTensor; +TfLiteInterpreterResizeInputTensor_t TfLiteInterpreterResizeInputTensor; +TfLiteInterpreterAllocateTensors_t TfLiteInterpreterAllocateTensors; +TfLiteInterpreterInvoke_t TfLiteInterpreterInvoke; +TfLiteInterpreterGetOutputTensorCount_t TfLiteInterpreterGetOutputTensorCount; +TfLiteInterpreterGetOutputTensor_t TfLiteInterpreterGetOutputTensor; +TfLiteTensorType_t TfLiteTensorType; +TfLiteTensorNumDims_t TfLiteTensorNumDims; +TfLiteTensorDim_t TfLiteTensorDim; +TfLiteTensorByteSize_t TfLiteTensorByteSize; +TfLiteTensorData_t TfLiteTensorData; +TfLiteTensorName_t TfLiteTensorName; +TfLiteTensorQuantizationParams_t TfLiteTensorQuantizationParams; +TfLiteTensorCopyFromBuffer_t TfLiteTensorCopyFromBuffer; +TfLiteTensorCopyToBuffer_t TfLiteTensorCopyToBuffer; bool tensorflowlite_init() { static ion::bb::dnn::DynamicModule dm("tensorflowlite_c"); @@ -76,7 +76,7 @@ bool tensorflowlite_init() { } #define RESOLVE_SYMBOL(SYM_NAME) \ - SYM_NAME = dm.get_symbol(#SYM_NAME); \ + SYM_NAME = dm.get_symbol(#SYM_NAME); \ if (SYM_NAME == nullptr) { \ throw std::runtime_error( \ #SYM_NAME " is unavailable on your edgetpu DSO"); \ diff --git a/src/bb/dnn/tensorflowlite_types.h b/src/bb/dnn/tensorflowlite_types.h index 53e9f5cf..94ceb5aa 100644 --- a/src/bb/dnn/tensorflowlite_types.h +++ b/src/bb/dnn/tensorflowlite_types.h @@ -45,9 +45,9 @@ extern "C" { #endif // __cplusplus typedef enum TfLiteStatus { - kTfLiteOk = 0, - kTfLiteError = 1, - kTfLiteDelegateError = 2 + kTfLiteOk = 0, + kTfLiteError = 1, + kTfLiteDelegateError = 2 } TfLiteStatus; // The list of external context types known to TF Lite. This list exists solely @@ -55,11 +55,11 @@ typedef enum TfLiteStatus { // need. Access to the external contexts is controlled by one of the // corresponding support files. typedef enum TfLiteExternalContextType { - kTfLiteEigenContext = 0, // include eigen_support.h to use. - kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. - kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. - kTfLiteCpuBackendContext = 3, // include cpu_backend_context.h to use. - kTfLiteMaxExternalContexts = 4 + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. + kTfLiteCpuBackendContext = 3, // include cpu_backend_context.h to use. + kTfLiteMaxExternalContexts = 4 } TfLiteExternalContextType; // Forward declare so dependent structs and methods can reference these types @@ -74,8 +74,8 @@ struct TfLiteRegistration; // refresh them if configurations like the number of recommended threads // change. typedef struct TfLiteExternalContext { - TfLiteExternalContextType type; - TfLiteStatus (*Refresh)(struct TfLiteContext* context); + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext *context); } TfLiteExternalContext; #define kTfLiteOptionalTensor (-1) @@ -83,15 +83,15 @@ typedef struct TfLiteExternalContext { // Fixed size list of integers. Used for dimensions and inputs/outputs tensor // indices typedef struct TfLiteIntArray { - int size; + int size; // gcc 6.1+ have a bug where flexible members aren't properly handled // https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c #if (!defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ __GNUC_MINOR__ >= 1) || \ defined(HEXAGON) - int data[0]; + int data[0]; #else - int data[]; + int data[]; #endif } TfLiteIntArray; @@ -102,36 +102,36 @@ int TfLiteIntArrayGetSizeInBytes(int size); #ifndef TF_LITE_STATIC_MEMORY // Create a array of a given `size` (uninitialized entries). // This returns a pointer, that you must free using TfLiteIntArrayFree(). -TfLiteIntArray* TfLiteIntArrayCreate(int size); +TfLiteIntArray *TfLiteIntArrayCreate(int size); #endif // Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise. -int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b); +int TfLiteIntArrayEqual(const TfLiteIntArray *a, const TfLiteIntArray *b); // Check if an intarray equals an array. Returns 1 if equals, 0 otherwise. -int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size, +int TfLiteIntArrayEqualsArray(const TfLiteIntArray *a, int b_size, const int b_data[]); #ifndef TF_LITE_STATIC_MEMORY // Create a copy of an array passed as `src`. // You are expected to free memory with TfLiteIntArrayFree -TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src); +TfLiteIntArray *TfLiteIntArrayCopy(const TfLiteIntArray *src); // Free memory of array `a`. -void TfLiteIntArrayFree(TfLiteIntArray* a); +void TfLiteIntArrayFree(TfLiteIntArray *a); #endif // TF_LITE_STATIC_MEMORY // Fixed size list of floats. Used for per-channel quantization. typedef struct TfLiteFloatArray { - int size; + int size; // gcc 6.1+ have a bug where flexible members aren't properly handled // https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c // This also applies to the toolchain used for Qualcomm Hexagon DSPs. #if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ __GNUC_MINOR__ >= 1 - float data[0]; + float data[0]; #else - float data[]; + float data[]; #endif } TfLiteFloatArray; @@ -142,10 +142,10 @@ int TfLiteFloatArrayGetSizeInBytes(int size); #ifndef TF_LITE_STATIC_MEMORY // Create a array of a given `size` (uninitialized entries). // This returns a pointer, that you must free using TfLiteFloatArrayFree(). -TfLiteFloatArray* TfLiteFloatArrayCreate(int size); +TfLiteFloatArray *TfLiteFloatArrayCreate(int size); // Free memory of array `a`. -void TfLiteFloatArrayFree(TfLiteFloatArray* a); +void TfLiteFloatArrayFree(TfLiteFloatArray *a); #endif // TF_LITE_STATIC_MEMORY // Since we must not depend on any libraries, define a minimal subset of @@ -156,17 +156,17 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a); // calling the context->ReportError function directly, so that message strings // can be stripped out if the binary size needs to be severely optimized. #ifndef TF_LITE_STRIP_ERROR_STRINGS -#define TF_LITE_KERNEL_LOG(context, ...) \ - do { \ - (context)->ReportError((context), __VA_ARGS__); \ - } while (false) - -#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) \ - do { \ - if ((context) != nullptr) { \ - (context)->ReportError((context), __VA_ARGS__); \ - } \ - } while (false) +#define TF_LITE_KERNEL_LOG(context, ...) \ + do { \ + (context)->ReportError((context), __VA_ARGS__); \ + } while (false) + +#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) \ + do { \ + if ((context) != nullptr) { \ + (context)->ReportError((context), __VA_ARGS__); \ + } \ + } while (false) #else // TF_LITE_STRIP_ERROR_STRINGS #define TF_LITE_KERNEL_LOG(context, ...) #define TF_LITE_MAYBE_KERNEL_LOG(context, ...) @@ -174,110 +174,110 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a); // Check whether value is true, and if not return kTfLiteError from // the current function (and report the error string msg). -#define TF_LITE_ENSURE_MSG(context, value, msg) \ - do { \ - if (!(value)) { \ - TF_LITE_KERNEL_LOG((context), __FILE__ " " msg); \ - return kTfLiteError; \ - } \ - } while (0) +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + TF_LITE_KERNEL_LOG((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) // Check whether the value `a` is true, and if not return kTfLiteError from // the current function, while also reporting the location of the error. -#define TF_LITE_ENSURE(context, a) \ - do { \ - if (!(a)) { \ - TF_LITE_KERNEL_LOG((context), "%s:%d %s was not true.", __FILE__, \ - __LINE__, #a); \ - return kTfLiteError; \ - } \ - } while (0) - -#define TF_LITE_ENSURE_STATUS(a) \ - do { \ - const TfLiteStatus s = (a); \ - if (s != kTfLiteOk) { \ - return s; \ - } \ - } while (0) +#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + TF_LITE_KERNEL_LOG((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + const TfLiteStatus s = (a); \ + if (s != kTfLiteOk) { \ + return s; \ + } \ + } while (0) // Check whether the value `a == b` is true, and if not return kTfLiteError from // the current function, while also reporting the location of the error. // `a` and `b` may be evaluated more than once, so no side effects or // extremely expensive computations should be done. // NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes. -#define TF_LITE_ENSURE_EQ(context, a, b) \ - do { \ - if ((a) != (b)) { \ - TF_LITE_KERNEL_LOG((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ - __LINE__, #a, #b, (a), (b)); \ - return kTfLiteError; \ - } \ - } while (0) - -#define TF_LITE_ENSURE_TYPES_EQ(context, a, b) \ - do { \ - if ((a) != (b)) { \ - TF_LITE_KERNEL_LOG((context), "%s:%d %s != %s (%s != %s)", __FILE__, \ - __LINE__, #a, #b, TfLiteTypeGetName(a), \ - TfLiteTypeGetName(b)); \ - return kTfLiteError; \ - } \ - } while (0) +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + TF_LITE_KERNEL_LOG((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_TYPES_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + TF_LITE_KERNEL_LOG((context), "%s:%d %s != %s (%s != %s)", __FILE__, \ + __LINE__, #a, #b, TfLiteTypeGetName(a), \ + TfLiteTypeGetName(b)); \ + return kTfLiteError; \ + } \ + } while (0) #define TF_LITE_ENSURE_OK(context, status) \ - do { \ - const TfLiteStatus s = (status); \ - if ((s) != kTfLiteOk) { \ - return s; \ - } \ - } while (0) + do { \ + const TfLiteStatus s = (status); \ + if ((s) != kTfLiteOk) { \ + return s; \ + } \ + } while (0) // Single-precision complex data type compatible with the C99 definition. typedef struct TfLiteComplex64 { - float re, im; // real and imaginary parts, respectively. + float re, im; // real and imaginary parts, respectively. } TfLiteComplex64; // Half precision data type compatible with the C99 definition. typedef struct TfLiteFloat16 { - uint16_t data; + uint16_t data; } TfLiteFloat16; // Types supported by tensor typedef enum { - kTfLiteNoType = 0, - kTfLiteFloat32 = 1, - kTfLiteInt32 = 2, - kTfLiteUInt8 = 3, - kTfLiteInt64 = 4, - kTfLiteString = 5, - kTfLiteBool = 6, - kTfLiteInt16 = 7, - kTfLiteComplex64 = 8, - kTfLiteInt8 = 9, - kTfLiteFloat16 = 10, - kTfLiteFloat64 = 11, + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, + kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, } TfLiteType; // Return the name of a given type, for error reporting purposes. -const char* TfLiteTypeGetName(TfLiteType type); +const char *TfLiteTypeGetName(TfLiteType type); // SupportedQuantizationTypes. typedef enum TfLiteQuantizationType { - // No quantization. - kTfLiteNoQuantization = 0, - // Affine quantization (with support for per-channel quantization). - // Corresponds to TfLiteAffineQuantization. - kTfLiteAffineQuantization = 1, + // No quantization. + kTfLiteNoQuantization = 0, + // Affine quantization (with support for per-channel quantization). + // Corresponds to TfLiteAffineQuantization. + kTfLiteAffineQuantization = 1, } TfLiteQuantizationType; // Structure specifying the quantization used by the tensor, if-any. typedef struct TfLiteQuantization { - // The type of quantization held by params. - TfLiteQuantizationType type; - // Holds a reference to one of the quantization param structures specified - // below. - void* params; + // The type of quantization held by params. + TfLiteQuantizationType type; + // Holds a reference to one of the quantization param structures specified + // below. + void *params; } TfLiteQuantization; // Legacy. Will be deprecated in favor of TfLiteAffineQuantization. @@ -287,8 +287,8 @@ typedef struct TfLiteQuantization { // back to float using: // real_value = scale * (quantized_value - zero_point) typedef struct TfLiteQuantizationParams { - float scale; - int32_t zero_point; + float scale; + int32_t zero_point; } TfLiteQuantizationParams; // Parameters for asymmetric quantization across a dimension (i.e per output @@ -299,29 +299,29 @@ typedef struct TfLiteQuantizationParams { // converted back to float using: // real_value = scale * (quantized_value - zero_point) typedef struct TfLiteAffineQuantization { - TfLiteFloatArray* scale; - TfLiteIntArray* zero_point; - int32_t quantized_dimension; + TfLiteFloatArray *scale; + TfLiteIntArray *zero_point; + int32_t quantized_dimension; } TfLiteAffineQuantization; /* A union of pointers that points to memory for a given tensor. */ typedef union TfLitePtrUnion { - /* Do not access these members directly, if possible, use - * GetTensorData(tensor) instead, otherwise only access .data, as other - * members are deprecated. */ - int32_t* i32; - int64_t* i64; - float* f; - TfLiteFloat16* f16; - char* raw; - const char* raw_const; - uint8_t* uint8; - bool* b; - int16_t* i16; - TfLiteComplex64* c64; - int8_t* int8; - /* Only use this member. */ - void* data; + /* Do not access these members directly, if possible, use + * GetTensorData(tensor) instead, otherwise only access .data, as other + * members are deprecated. */ + int32_t *i32; + int64_t *i64; + float *f; + TfLiteFloat16 *f16; + char *raw; + const char *raw_const; + uint8_t *uint8; + bool *b; + int16_t *i16; + TfLiteComplex64 *c64; + int8_t *int8; + /* Only use this member. */ + void *data; } TfLitePtrUnion; // Memory allocation strategies. @@ -335,110 +335,110 @@ typedef union TfLitePtrUnion { // useful for tensors that can be computed during prepare and treated // as constant inputs for downstream ops (also in prepare). typedef enum TfLiteAllocationType { - kTfLiteMemNone = 0, - kTfLiteMmapRo, - kTfLiteArenaRw, - kTfLiteArenaRwPersistent, - kTfLiteDynamic, - kTfLitePersistentRo, + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, + kTfLitePersistentRo, } TfLiteAllocationType; // The delegates should use zero or positive integers to represent handles. // -1 is reserved from unallocated status. typedef int TfLiteBufferHandle; enum { - kTfLiteNullBufferHandle = -1, + kTfLiteNullBufferHandle = -1, }; // Storage format of each dimension in a sparse tensor. typedef enum TfLiteDimensionType { - kTfLiteDimDense = 0, - kTfLiteDimSparseCSR, + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, } TfLiteDimensionType; // Metadata to encode each dimension in a sparse tensor. typedef struct TfLiteDimensionMetadata { - TfLiteDimensionType format; - int dense_size; - TfLiteIntArray* array_segments; - TfLiteIntArray* array_indices; + TfLiteDimensionType format; + int dense_size; + TfLiteIntArray *array_segments; + TfLiteIntArray *array_indices; } TfLiteDimensionMetadata; // Parameters used to encode a sparse tensor. For detailed explanation of each // field please refer to lite/schema/schema.fbs. typedef struct TfLiteSparsity { - TfLiteIntArray* traversal_order; - TfLiteIntArray* block_map; - TfLiteDimensionMetadata* dim_metadata; - int dim_metadata_size; + TfLiteIntArray *traversal_order; + TfLiteIntArray *block_map; + TfLiteDimensionMetadata *dim_metadata; + int dim_metadata_size; } TfLiteSparsity; // An tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). #ifndef TF_LITE_STATIC_MEMORY typedef struct TfLiteTensor { - // The data type specification for data stored in `data`. This affects - // what member of `data` union should be used. - TfLiteType type; - // A union of data pointers. The appropriate type should be used for a typed - // tensor based on `type`. - TfLitePtrUnion data; - // A pointer to a structure representing the dimensionality interpretation - // that the buffer should have. NOTE: the product of elements of `dims` - // and the element datatype size should be equal to `bytes` below. - TfLiteIntArray* dims; - // Quantization information. - TfLiteQuantizationParams params; - // How memory is mapped - // kTfLiteMmapRo: Memory mapped read only. - // i.e. weights - // kTfLiteArenaRw: Arena allocated read write memory - // (i.e. temporaries, outputs). - TfLiteAllocationType allocation_type; - // The number of bytes required to store the data of this Tensor. I.e. - // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if - // type is kTfLiteFloat32 and dims = {3, 2} then - // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. - size_t bytes; - - // An opaque pointer to a tflite::MMapAllocation - const void* allocation; - - // Null-terminated name of this tensor. - const char* name; - - // The delegate which knows how to handle `buffer_handle`. - // WARNING: This is an experimental interface that is subject to change. - struct TfLiteDelegate* delegate; - - // An integer buffer handle that can be handled by `delegate`. - // The value is valid only when delegate is not null. - // WARNING: This is an experimental interface that is subject to change. - TfLiteBufferHandle buffer_handle; - - // If the delegate uses its own buffer (e.g. GPU memory), the delegate is - // responsible to set data_is_stale to true. - // `delegate->CopyFromBufferHandle` can be called to copy the data from - // delegate buffer. - // WARNING: This is an // experimental interface that is subject to change. - bool data_is_stale; - - // True if the tensor is a variable. - bool is_variable; - - // Quantization information. Replaces params field above. - TfLiteQuantization quantization; - - // Parameters used to encode a sparse tensor. - // This is optional. The field is NULL if a tensor is dense. - // WARNING: This is an experimental interface that is subject to change. - TfLiteSparsity* sparsity; - - // Optional. Encodes shapes with unknown dimensions with -1. This field is - // only populated when unknown dimensions exist in a read-write tensor (i.e. - // an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and - // `dims_signature` contains [1, -1, -1, 3]). - const TfLiteIntArray* dims_signature; + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray *dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void *allocation; + + // Null-terminated name of this tensor. + const char *name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + struct TfLiteDelegate *delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; + + // Quantization information. Replaces params field above. + TfLiteQuantization quantization; + + // Parameters used to encode a sparse tensor. + // This is optional. The field is NULL if a tensor is dense. + // WARNING: This is an experimental interface that is subject to change. + TfLiteSparsity *sparsity; + + // Optional. Encodes shapes with unknown dimensions with -1. This field is + // only populated when unknown dimensions exist in a read-write tensor (i.e. + // an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and + // `dims_signature` contains [1, -1, -1, 3]). + const TfLiteIntArray *dims_signature; } TfLiteTensor; #else // Specific reduced TfLiteTensor struct for TF Micro runtime. This struct @@ -448,104 +448,104 @@ typedef struct TfLiteTensor { // // NOTE: This flag is opt-in only at compile time. typedef struct TfLiteTensor { - // TODO(b/155784997): Consider consolidating these quantization fields: - // Quantization information. Replaces params field above. - TfLiteQuantization quantization; - - // Quantization information. - TfLiteQuantizationParams params; - - // A union of data pointers. The appropriate type should be used for a typed - // tensor based on `type`. - TfLitePtrUnion data; - - // A pointer to a structure representing the dimensionality interpretation - // that the buffer should have. NOTE: the product of elements of `dims` - // and the element datatype size should be equal to `bytes` below. - TfLiteIntArray* dims; - - // The number of bytes required to store the data of this Tensor. I.e. - // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if - // type is kTfLiteFloat32 and dims = {3, 2} then - // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. - size_t bytes; - - // The data type specification for data stored in `data`. This affects - // what member of `data` union should be used. - TfLiteType type; - - // How memory is mapped - // kTfLiteMmapRo: Memory mapped read only. - // i.e. weights - // kTfLiteArenaRw: Arena allocated read write memory - // (i.e. temporaries, outputs). - TfLiteAllocationType allocation_type; - - // True if the tensor is a variable. - bool is_variable; + // TODO(b/155784997): Consider consolidating these quantization fields: + // Quantization information. Replaces params field above. + TfLiteQuantization quantization; + + // Quantization information. + TfLiteQuantizationParams params; + + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray *dims; + + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + + // True if the tensor is a variable. + bool is_variable; } TfLiteTensor; #endif // TF_LITE_STATIC_MEMORY #ifndef TF_LITE_STATIC_MEMORY // Free data memory of tensor `t`. -void TfLiteTensorDataFree(TfLiteTensor* t); +void TfLiteTensorDataFree(TfLiteTensor *t); // Free quantization data. -void TfLiteQuantizationFree(TfLiteQuantization* quantization); +void TfLiteQuantizationFree(TfLiteQuantization *quantization); // Free sparsity parameters. -void TfLiteSparsityFree(TfLiteSparsity* sparsity); +void TfLiteSparsityFree(TfLiteSparsity *sparsity); // Free memory of tensor `t`. -void TfLiteTensorFree(TfLiteTensor* t); +void TfLiteTensorFree(TfLiteTensor *t); // Set all of a tensor's fields (and free any previously allocated data). -void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, - TfLiteQuantizationParams quantization, char* buffer, +void TfLiteTensorReset(TfLiteType type, const char *name, TfLiteIntArray *dims, + TfLiteQuantizationParams quantization, char *buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, bool is_variable, - TfLiteTensor* tensor); + const void *allocation, bool is_variable, + TfLiteTensor *tensor); // Resize the allocated data of a (dynamic) tensor. Tensors with allocation // types other than kTfLiteDynamic will be ignored. -void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor *tensor); #endif // TF_LITE_STATIC_MEMORY // A structure representing an instance of a node. // This structure only exhibits the inputs, outputs and user defined data, not // other features like the type. typedef struct TfLiteNode { - // Inputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* inputs; + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray *inputs; - // Outputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* outputs; + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray *outputs; - // intermediate tensors to this node expressed as indices into the simulator's - // tensors. - TfLiteIntArray* intermediates; + // intermediate tensors to this node expressed as indices into the simulator's + // tensors. + TfLiteIntArray *intermediates; - // Temporary tensors uses during the computations. This usually contains no - // tensors, but ops are allowed to change that if they need scratch space of - // any sort. - TfLiteIntArray* temporaries; + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray *temporaries; - // Opaque data provided by the node implementer through `Registration.init`. - void* user_data; + // Opaque data provided by the node implementer through `Registration.init`. + void *user_data; - // Opaque data provided to the node if the node is a builtin. This is usually - // a structure defined in builtin_op_data.h - void* builtin_data; + // Opaque data provided to the node if the node is a builtin. This is usually + // a structure defined in builtin_op_data.h + void *builtin_data; - // Custom initial data. This is the opaque data provided in the flatbuffer. - // WARNING: This is an experimental interface that is subject to change. - const void* custom_initial_data; - int custom_initial_data_size; + // Custom initial data. This is the opaque data provided in the flatbuffer. + // WARNING: This is an experimental interface that is subject to change. + const void *custom_initial_data; + int custom_initial_data_size; - // The pointer to the delegate. This is non-null only when the node is - // created by calling `interpreter.ModifyGraphWithDelegate`. - // WARNING: This is an experimental interface that is subject to change. - struct TfLiteDelegate* delegate; + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + struct TfLiteDelegate *delegate; } TfLiteNode; // WARNING: This is an experimental interface that is subject to change. @@ -556,266 +556,266 @@ typedef struct TfLiteNode { // // See also the `CreateDelegateParams` function in `interpreter.cc` details. typedef struct TfLiteDelegateParams { - struct TfLiteDelegate* delegate; - TfLiteIntArray* nodes_to_replace; - TfLiteIntArray* input_tensors; - TfLiteIntArray* output_tensors; + struct TfLiteDelegate *delegate; + TfLiteIntArray *nodes_to_replace; + TfLiteIntArray *input_tensors; + TfLiteIntArray *output_tensors; } TfLiteDelegateParams; typedef struct TfLiteContext { - // Number of tensors in the context. - size_t tensors_size; - - // The execution plan contains a list of the node indices in execution - // order. execution_plan->size is the current number of nodes. And, - // execution_plan->data[0] is the first node that needs to be run. - // TfLiteDelegates can traverse the current execution plan by iterating - // through each member of this array and using GetNodeAndRegistration() to - // access details about a node. i.e. - // TfLiteIntArray* execution_plan; - // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); - // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { - // int node_index = execution_plan->data[exec_index]; - // TfLiteNode* node; - // TfLiteRegistration* reg; - // context->GetNodeAndRegistration(context, node_index, &node, ®); - // } - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, - TfLiteIntArray** execution_plan); - - // An array of tensors in the interpreter context (of length `tensors_size`) - TfLiteTensor* tensors; - - // opaque full context ptr (an opaque c++ data structure) - void* impl_; - - // Request memory pointer be resized. Updates dimensions on the tensor. - // NOTE: ResizeTensor takes ownership of newSize. - TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, - TfLiteIntArray* new_size); - // Request that an error be reported with format string msg. - void (*ReportError)(struct TfLiteContext*, const char* msg, ...); - - // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If - // non-null, the value pointed to by `first_new_tensor_index` will be set to - // the index of the first new tensor. - TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, - int* first_new_tensor_index); - - // Get a Tensor node by node_index. - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*GetNodeAndRegistration)( - struct TfLiteContext*, int node_index, TfLiteNode** node, - struct TfLiteRegistration** registration); - - // Replace ops with one or more stub delegate operations. This function - // does not take ownership of `nodes_to_replace`. - TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)( - struct TfLiteContext*, struct TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate); - - // Number of threads that are recommended to subsystems like gemmlowp and - // eigen. - int recommended_num_threads; - - // Access external contexts by type. - // WARNING: This is an experimental interface that is subject to change. - TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, - TfLiteExternalContextType); - // Set the value of a external context. Does not take ownership of the - // pointer. - // WARNING: This is an experimental interface that is subject to change. - void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, - TfLiteExternalContext*); - - // Flag for allowing float16 precision for FP32 calculation. - // default: false. - // WARNING: This is an experimental API and subject to change. - bool allow_fp32_relax_to_fp16; - - // Pointer to the op-level profiler, if set; nullptr otherwise. - void* profiler; - - // Allocate persistent buffer which has the same life time as the interpreter. - // The memory is allocated from heap for TFL, and from tail in TFLM. - // If *ptr is not nullptr, the pointer will be reallocated. - // This method is only available in Prepare stage. - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*AllocatePersistentBuffer)(struct TfLiteContext* ctx, - size_t bytes, void** ptr); - - // Allocate a buffer which will be deallocated right after invoke phase. - // The memory is allocated from heap in TFL, and from volatile arena in TFLM. - // This method is only available in invoke stage. - // NOTE: If possible use RequestScratchBufferInArena method to avoid memory - // allocation during inference time. - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*AllocateBufferForEval)(struct TfLiteContext* ctx, size_t bytes, - void** ptr); - - // Request a scratch buffer in the arena through static memory planning. - // This method is only available in Prepare stage and the buffer is allocated - // by the interpreter between Prepare and Eval stage. In Eval stage, - // GetScratchBuffer API can be used to fetch the address. - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*RequestScratchBufferInArena)(struct TfLiteContext* ctx, - size_t bytes, int* buffer_idx); - - // Get the scratch buffer pointer. - // This method is only available in Eval stage. - // WARNING: This is an experimental interface that is subject to change. - void* (*GetScratchBuffer)(struct TfLiteContext* ctx, int buffer_idx); - - // Resize the memory pointer of the `tensor`. This method behaves the same as - // `ResizeTensor`, except that it makes a copy of the shape array internally - // so the shape array could be deallocated right afterwards. - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*ResizeTensorExplicit)(struct TfLiteContext* ctx, - TfLiteTensor* tensor, int dims, - const int* shape); - - // This method provides a preview of post-delegation partitioning. Each - // TfLiteDelegateParams in the referenced array corresponds to one instance of - // the delegate kernel. - // Example usage: - // - // TfLiteIntArray* nodes_to_replace = ...; - // TfLiteDelegateParams* params_array; - // int num_partitions = 0; - // TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning( - // context, delegate, nodes_to_replace, ¶ms_array, &num_partitions)); - // for (int idx = 0; idx < num_partitions; idx++) { - // const auto& partition_params = params_array[idx]; - // ... - // } - // - // NOTE: The context owns the memory referenced by partition_params_array. It - // will be cleared with another call to PreviewDelegateParitioning, or after - // TfLiteDelegateParams::Prepare returns. - // - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*PreviewDelegatePartitioning)( - struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, - TfLiteDelegateParams** partition_params_array, int* num_partitions); + // Number of tensors in the context. + size_t tensors_size; + + // The execution plan contains a list of the node indices in execution + // order. execution_plan->size is the current number of nodes. And, + // execution_plan->data[0] is the first node that needs to be run. + // TfLiteDelegates can traverse the current execution plan by iterating + // through each member of this array and using GetNodeAndRegistration() to + // access details about a node. i.e. + // TfLiteIntArray* execution_plan; + // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { + // int node_index = execution_plan->data[exec_index]; + // TfLiteNode* node; + // TfLiteRegistration* reg; + // context->GetNodeAndRegistration(context, node_index, &node, ®); + // } + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext *context, + TfLiteIntArray **execution_plan); + + // An array of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor *tensors; + + // opaque full context ptr (an opaque c++ data structure) + void *impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext *, TfLiteTensor *tensor, + TfLiteIntArray *new_size); + // Request that an error be reported with format string msg. + void (*ReportError)(struct TfLiteContext *, const char *msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext *, int tensors_to_add, + int *first_new_tensor_index); + + // Get a Tensor node by node_index. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetNodeAndRegistration)( + struct TfLiteContext *, int node_index, TfLiteNode **node, + struct TfLiteRegistration **registration); + + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. + TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)( + struct TfLiteContext *, struct TfLiteRegistration registration, + const TfLiteIntArray *nodes_to_replace, struct TfLiteDelegate *delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; + + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext *(*GetExternalContext)(struct TfLiteContext *, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext *, TfLiteExternalContextType, + TfLiteExternalContext *); + + // Flag for allowing float16 precision for FP32 calculation. + // default: false. + // WARNING: This is an experimental API and subject to change. + bool allow_fp32_relax_to_fp16; + + // Pointer to the op-level profiler, if set; nullptr otherwise. + void *profiler; + + // Allocate persistent buffer which has the same life time as the interpreter. + // The memory is allocated from heap for TFL, and from tail in TFLM. + // If *ptr is not nullptr, the pointer will be reallocated. + // This method is only available in Prepare stage. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*AllocatePersistentBuffer)(struct TfLiteContext *ctx, + size_t bytes, void **ptr); + + // Allocate a buffer which will be deallocated right after invoke phase. + // The memory is allocated from heap in TFL, and from volatile arena in TFLM. + // This method is only available in invoke stage. + // NOTE: If possible use RequestScratchBufferInArena method to avoid memory + // allocation during inference time. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*AllocateBufferForEval)(struct TfLiteContext *ctx, size_t bytes, + void **ptr); + + // Request a scratch buffer in the arena through static memory planning. + // This method is only available in Prepare stage and the buffer is allocated + // by the interpreter between Prepare and Eval stage. In Eval stage, + // GetScratchBuffer API can be used to fetch the address. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*RequestScratchBufferInArena)(struct TfLiteContext *ctx, + size_t bytes, int *buffer_idx); + + // Get the scratch buffer pointer. + // This method is only available in Eval stage. + // WARNING: This is an experimental interface that is subject to change. + void *(*GetScratchBuffer)(struct TfLiteContext *ctx, int buffer_idx); + + // Resize the memory pointer of the `tensor`. This method behaves the same as + // `ResizeTensor`, except that it makes a copy of the shape array internally + // so the shape array could be deallocated right afterwards. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*ResizeTensorExplicit)(struct TfLiteContext *ctx, + TfLiteTensor *tensor, int dims, + const int *shape); + + // This method provides a preview of post-delegation partitioning. Each + // TfLiteDelegateParams in the referenced array corresponds to one instance of + // the delegate kernel. + // Example usage: + // + // TfLiteIntArray* nodes_to_replace = ...; + // TfLiteDelegateParams* params_array; + // int num_partitions = 0; + // TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning( + // context, delegate, nodes_to_replace, ¶ms_array, &num_partitions)); + // for (int idx = 0; idx < num_partitions; idx++) { + // const auto& partition_params = params_array[idx]; + // ... + // } + // + // NOTE: The context owns the memory referenced by partition_params_array. It + // will be cleared with another call to PreviewDelegateParitioning, or after + // TfLiteDelegateParams::Prepare returns. + // + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*PreviewDelegatePartitioning)( + struct TfLiteContext *context, const TfLiteIntArray *nodes_to_replace, + TfLiteDelegateParams **partition_params_array, int *num_partitions); } TfLiteContext; typedef struct TfLiteRegistration { - // Initializes the op from serialized data. - // If a built-in op: - // `buffer` is the op's params data (TfLiteLSTMParams*). - // `length` is zero. - // If custom op: - // `buffer` is the op's `custom_options`. - // `length` is the size of the buffer. - // - // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer - // or an instance of a struct). - // - // The returned pointer will be stored with the node in the `user_data` field, - // accessible within prepare and invoke functions below. - // NOTE: if the data is already in the desired format, simply implement this - // function to return `nullptr` and implement the free function to be a no-op. - void* (*init)(TfLiteContext* context, const char* buffer, size_t length); - - // The pointer `buffer` is the data previously returned by an init invocation. - void (*free)(TfLiteContext* context, void* buffer); - - // prepare is called when the inputs this node depends on have been resized. - // context->ResizeTensor() can be called to request output tensors to be - // resized. - // - // Returns kTfLiteOk on success. - TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); - - // Execute the node (should read node->inputs and output to node->outputs). - // Returns kTfLiteOk on success. - TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); - - // profiling_string is called during summarization of profiling information - // in order to group executions together. Providing a value here will cause a - // given op to appear multiple times is the profiling report. This is - // particularly useful for custom ops that can perform significantly - // different calculations depending on their `user-data`. - const char* (*profiling_string)(const TfLiteContext* context, - const TfLiteNode* node); - - // Builtin codes. If this kernel refers to a builtin this is the code - // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. - // Note: It is the responsibility of the registration binder to set this - // properly. - int32_t builtin_code; - - // Custom op name. If the op is a builtin, this will be null. - // Note: It is the responsibility of the registration binder to set this - // properly. - // WARNING: This is an experimental interface that is subject to change. - const char* custom_name; - - // The version of the op. - // Note: It is the responsibility of the registration binder to set this - // properly. - int version; + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void *(*init)(TfLiteContext *context, const char *buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext *context, void *buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext *context, TfLiteNode *node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext *context, TfLiteNode *node); + + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char *(*profiling_string)(const TfLiteContext *context, + const TfLiteNode *node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. + int32_t builtin_code; + + // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. + // WARNING: This is an experimental interface that is subject to change. + const char *custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; } TfLiteRegistration; // The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the // values should be 1, 2, 4, 8, ...etc. typedef enum TfLiteDelegateFlags { - kTfLiteDelegateFlagsNone = 0, - // The flag is set if the delegate can handle dynamic sized tensors. - // For example, the output shape of a `Resize` op with non-constant shape - // can only be inferred when the op is invoked. - // In this case, the Delegate is responsible for calling - // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling - // `ResizeTensor` when invoking the op. - // - // If the delegate isn't capable to handle dynamic tensors, this flag need - // to be set to false. - kTfLiteDelegateFlagsAllowDynamicTensors = 1 + kTfLiteDelegateFlagsNone = 0, + // The flag is set if the delegate can handle dynamic sized tensors. + // For example, the output shape of a `Resize` op with non-constant shape + // can only be inferred when the op is invoked. + // In this case, the Delegate is responsible for calling + // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling + // `ResizeTensor` when invoking the op. + // + // If the delegate isn't capable to handle dynamic tensors, this flag need + // to be set to false. + kTfLiteDelegateFlagsAllowDynamicTensors = 1 } TfLiteDelegateFlags; // WARNING: This is an experimental interface that is subject to change. typedef struct TfLiteDelegate { - // Data that delegate needs to identify itself. This data is owned by the - // delegate. The delegate is owned in the user code, so the delegate is - // responsible for doing this when it is destroyed. - void* data_; - - // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the - // delegate a view of the current graph through TfLiteContext*. It typically - // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels() - // to ask the TensorFlow lite runtime to create macro-nodes to represent - // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(TfLiteContext* context, - struct TfLiteDelegate* delegate); - - // Copy the data from delegate buffer handle into raw memory of the given - // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as - // long as it follows the rules for kTfLiteDynamic tensors, in which case this - // cannot be null. - TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, - struct TfLiteDelegate* delegate, + // Data that delegate needs to identify itself. This data is owned by the + // delegate. The delegate is owned in the user code, so the delegate is + // responsible for doing this when it is destroyed. + void *data_; + + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the + // delegate a view of the current graph through TfLiteContext*. It typically + // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels() + // to ask the TensorFlow lite runtime to create macro-nodes to represent + // delegated subgraphs of the original graph. + TfLiteStatus (*Prepare)(TfLiteContext *context, + struct TfLiteDelegate *delegate); + + // Copy the data from delegate buffer handle into raw memory of the given + // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as + // long as it follows the rules for kTfLiteDynamic tensors, in which case this + // cannot be null. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext *context, + struct TfLiteDelegate *delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor *tensor); + + // Copy the data from raw memory of the given 'tensor' to delegate buffer + // handle. This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteContext *context, + struct TfLiteDelegate *delegate, TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor); - - // Copy the data from raw memory of the given 'tensor' to delegate buffer - // handle. This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, - struct TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor); - - // Free the Delegate Buffer Handle. Note: This only frees the handle, but - // this doesn't release the underlying resource (e.g. textures). The - // resources are either owned by application layer or the delegate. - // This can be null if the delegate doesn't use its own buffer. - void (*FreeBufferHandle)(TfLiteContext* context, - struct TfLiteDelegate* delegate, - TfLiteBufferHandle* handle); - - // Bitmask flags. See the comments in `TfLiteDelegateFlags`. - int64_t flags; + TfLiteTensor *tensor); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteContext *context, + struct TfLiteDelegate *delegate, + TfLiteBufferHandle *handle); + + // Bitmask flags. See the comments in `TfLiteDelegateFlags`. + int64_t flags; } TfLiteDelegate; // Build a 'null' delegate, with all the fields properly set to their default @@ -825,4 +825,4 @@ TfLiteDelegate TfLiteDelegateCreate(); #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // ION_BB_DNN_TENSORFLOWLITE_TYPES_H +#endif // ION_BB_DNN_TENSORFLOWLITE_TYPES_H diff --git a/src/bb/fpga/bb.h b/src/bb/fpga/bb.h index d1398018..095e8895 100644 --- a/src/bb/fpga/bb.h +++ b/src/bb/fpga/bb.h @@ -118,7 +118,7 @@ class BayerOffset : public BuildingBlock { GeneratorParam gc_mandatory{"gc_mandatory", "width,height"}; GeneratorParam gc_strategy{"gc_strategy", "inlinable"}; - //GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; + // GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; GeneratorParam bayer_pattern{"bayer_pattern", 0, 0, 3}; GeneratorParam width{"width", 0}; GeneratorParam height{"height", 0}; @@ -175,7 +175,7 @@ class BayerWhiteBalance : public BuildingBlock { GeneratorParam gc_mandatory{"gc_mandatory", "width,height"}; GeneratorParam gc_strategy{"gc_strategy", "inlinable"}; - //GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; + // GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; GeneratorParam bayer_pattern{"bayer_pattern", 0, 0, 3}; GeneratorParam width{"width", 0}; GeneratorParam height{"height", 0}; @@ -257,7 +257,7 @@ class BayerDemosaicSimple : public BuildingBlock { GeneratorParam gc_mandatory{"gc_mandatory", "width,height"}; GeneratorParam gc_strategy{"gc_strategy", "inlinable"}; - //GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; + // GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; GeneratorParam bayer_pattern{"bayer_pattern", 0, 0, 3}; GeneratorParam width{"width", 0}; GeneratorParam height{"height", 0}; @@ -427,7 +427,7 @@ class LensShadingCorrectionLinear : public BuildingBlock gc_mandatory{"gc_mandatory", "width,height"}; GeneratorParam gc_strategy{"gc_strategy", "inlinable"}; - //GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; + // GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; GeneratorParam bayer_pattern{"bayer_pattern", 0, 0, 3}; // Max 16bit GeneratorParam width{"width", 0, 0, 65535}; @@ -488,7 +488,7 @@ class CalcLuminance : public BuildingBlock { GeneratorParam gc_mandatory{"gc_mandatory", "width,height"}; GeneratorParam gc_strategy{"gc_strategy", "inlinable"}; - //GeneratorParam luminance_method { "luminance_method", Luminance::Method::SimpleY, Luminance::enum_map }; + // GeneratorParam luminance_method { "luminance_method", Luminance::Method::SimpleY, Luminance::enum_map }; GeneratorParam luminance_method{"luminance_method", 2, 0, 2}; GeneratorParam width{"width", 0}; GeneratorParam height{"height", 0}; @@ -742,7 +742,7 @@ class SimpleISP : public BuildingBlock { GeneratorParam gc_prefix{"gc_prefix", ""}; GeneratorParam gc_required_features{"gc_required_features", "vivado_hls"}; - //GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; + // GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; GeneratorParam bayer_pattern{"bayer_pattern", 0, 0, 3}; // Max 16bit GeneratorParam width{"width", 0, 0, 65535}; @@ -850,7 +850,7 @@ class SimpleISPWithUnsharpMask : public BuildingBlock GeneratorParam gc_prefix{"gc_prefix", ""}; GeneratorParam gc_required_features{"gc_required_features", "vivado_hls"}; - //GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; + // GeneratorParam bayer_pattern { "bayer_pattern", BayerMap::Pattern::RGGB, BayerMap::enum_map }; GeneratorParam bayer_pattern{"bayer_pattern", 0, 0, 3}; // Max 16bit GeneratorParam width{"width", 0, 0, 65535}; diff --git a/src/bb/image-io/bb.h b/src/bb/image-io/bb.h index 0d3bc7f8..c030868d 100644 --- a/src/bb/image-io/bb.h +++ b/src/bb/image-io/bb.h @@ -61,8 +61,7 @@ const int BayerMap::bayer_map[4][4]{ }; #ifdef __linux__ -uint32_t make_pixel_format(BayerMap::Pattern bayer_pattern, int32_t bit_width) -{ +uint32_t make_pixel_format(BayerMap::Pattern bayer_pattern, int32_t bit_width) { uint32_t pix_format; switch (bit_width * 10 + static_cast(static_cast(bayer_pattern))) { case 80: // RGGB 8bit @@ -110,7 +109,6 @@ uint32_t make_pixel_format(BayerMap::Pattern bayer_pattern, int32_t bit_width) int instance_id = 0; - class Camera : public ion::BuildingBlock { public: BuildingBlockParam gc_title{"gc_title", "USBCamera"}; @@ -180,22 +178,17 @@ class Camera2 : public ion::BuildingBlock { BuildingBlockParam url0{"url0", ""}; BuildingBlockParam url1{"url1", ""}; - - Output output0{"output0", Halide::type_of(), 3}; Output output1{"output1", Halide::type_of(), 3}; - void generate() { using namespace Halide; - - for (int i =0; i < num_devices; i++){ + for (int i = 0; i < num_devices; i++) { std::string url_str; - if(i == 0){ + if (i == 0) { url_str = url0; - } - else{ + } else { url_str = url1; } @@ -221,23 +214,17 @@ class Camera2 : public ion::BuildingBlock { Expr g = saturating_cast(yv - cast(0.344f) * (uv - f128) - (cast(0.714f) * (vv - f128))); Expr b = saturating_cast(yv + cast(1.773f) * (uv - f128)); - - - Func f(static_cast(gc_prefix) + "output" + std::to_string(i)); f(x, y, c) = mux(c, {r, g, b}); - - if (i ==0) + if (i == 0) output0 = f; else output1 = f; } - } }; - class CameraN : public ion::BuildingBlock { public: BuildingBlockParam num_devices{"num_devices", 2}; @@ -257,33 +244,27 @@ class CameraN : public ion::BuildingBlock { Output output{"output", Halide::type_of(), 3}; - void generate() { std::stringstream urls_stream(urls); std::string url; std::vector url_list; - while(std::getline(urls_stream, url, ';')) - { + while (std::getline(urls_stream, url, ';')) { url_list.push_back(url); } - using namespace Halide; output.resize(num_devices); - for (int i =0; i < num_devices; i++){ + for (int i = 0; i < num_devices; i++) { std::string url_str; - if (url_list.size()!=0){ + if (url_list.size() != 0) { url_str = url_list[i]; - } - else{ + } else { url_str = ""; } - - Halide::Buffer url_buf(url_str.size() + 1); url_buf.fill(0); std::memcpy(url_buf.data(), url_str.c_str(), url_str.size()); @@ -306,13 +287,11 @@ class CameraN : public ion::BuildingBlock { Expr g = saturating_cast(yv - cast(0.344f) * (uv - f128) - (cast(0.714f) * (vv - f128))); Expr b = saturating_cast(yv + cast(1.773f) * (uv - f128)); - Func f(static_cast(gc_prefix) + "output" + std::to_string(i)); f(x, y, c) = mux(c, {r, g, b}); output[i](_) = f(_); } - } }; @@ -353,8 +332,7 @@ class IMX219 : public ion::BuildingBlock { url_buf, 0.4f, 0.5f, 0.3125f, 0.0625f, - 10, 6 - }; + 10, 6}; Func v4l2_imx219(static_cast(gc_prefix) + "output"); v4l2_imx219.define_extern("ion_bb_image_io_v4l2", params, type_of(), 2); v4l2_imx219.compute_root(); @@ -398,7 +376,6 @@ class D435 : public ion::BuildingBlock { } }; - class GenericV4L2Bayer : public ion::BuildingBlock { public: BuildingBlockParam gc_title{"gc_title", "GenericV4L2Bayer"}; @@ -437,8 +414,7 @@ class GenericV4L2Bayer : public ion::BuildingBlock { url_buf, 1.f, 1.f, 1.f, 0.f, - cast(bit_width), 16 - bit_width - }; + cast(bit_width), 16 - bit_width}; Func v4l2(static_cast(gc_prefix) + "output"); v4l2.define_extern("ion_bb_image_io_v4l2", params, type_of(), 2); v4l2.compute_root(); @@ -489,8 +465,7 @@ class CameraSimulation : public ion::BuildingBlock { url_buf, cast(gain_r), cast(gain_g), cast(gain_b), cast(offset), - cast(bit_width), cast(bit_shift) - }; + cast(bit_width), cast(bit_shift)}; Func camera(static_cast(gc_prefix) + "output"); camera.define_extern("ion_bb_image_io_v4l2", params, type_of(), 2); camera.compute_root(); @@ -500,7 +475,6 @@ class CameraSimulation : public ion::BuildingBlock { }; #endif - class GUIDisplay : public ion::BuildingBlock { public: BuildingBlockParam gc_title{"gc_title", "GUI Display"}; @@ -661,7 +635,6 @@ class ColorDataLoader : public ion::BuildingBlock { } }; - class ImageSaver : public ion::BuildingBlock { public: BuildingBlockParam gc_title{"gc_title", "Image Saver"}; @@ -710,17 +683,16 @@ class ImageSaver : public ion::BuildingBlock { template class U3VCamera1 : public ion::BuildingBlock> { public: - BuildingBlockParam frame_sync{"frame_sync", false}; BuildingBlockParam gain_key_ptr{"gain_key", "Gain"}; BuildingBlockParam exposure_key_ptr{"exposure_key", "Exposure"}; BuildingBlockParam realtime_display_mode{"realtime_display_mode", false}; - Input gain0{ "gain0" }; - Input exposure0{ "exposure0" }; + Input gain0{"gain0"}; + Input exposure0{"exposure0"}; - Output output0{ "output0", Halide::type_of(), D}; - Output frame_count{ "frame_count", Halide::type_of(), 1 }; + Output output0{"output0", Halide::type_of(), D}; + Output frame_count{"frame_count", Halide::type_of(), 1}; void generate() { using namespace Halide; @@ -740,10 +712,9 @@ class U3VCamera1 : public ion::BuildingBlock> { std::memcpy(exposure_key_buf.data(), exposure_key.c_str(), exposure_key.size()); std::vector params{ - static_cast(frame_sync), static_cast(realtime_display_mode), - gain0, exposure0, - id_buf, gain_key_buf, exposure_key_buf - }; + static_cast(frame_sync), static_cast(realtime_display_mode), + gain0, exposure0, + id_buf, gain_key_buf, exposure_key_buf}; camera1.define_extern("ion_bb_image_io_u3v_camera1", params, Halide::type_of(), D); camera1.compute_root(); output0(_) = camera1(_); @@ -752,13 +723,12 @@ class U3VCamera1 : public ion::BuildingBlock> { Func camera1_frame_count; { Buffer id_buf = this->get_id(); - camera1_frame_count.define_extern("ion_bb_image_io_u3v_camera1_frame_count",{camera1, 1, static_cast(frame_sync), static_cast(realtime_display_mode), id_buf}, type_of(), 1); + camera1_frame_count.define_extern("ion_bb_image_io_u3v_camera1_frame_count", {camera1, 1, static_cast(frame_sync), static_cast(realtime_display_mode), id_buf}, type_of(), 1); camera1_frame_count.compute_root(); frame_count(_) = camera1_frame_count(_); } this->register_disposer("u3v_dispose"); - } }; @@ -769,20 +739,19 @@ using U3VCamera1_U16x2 = U3VCamera1; template class U3VCamera2 : public ion::BuildingBlock> { public: - BuildingBlockParam frame_sync{"frame_sync", false}; BuildingBlockParam gain_key_ptr{"gain_key", "Gain"}; BuildingBlockParam exposure_key_ptr{"exposure_key", "Exposure"}; BuildingBlockParam realtime_display_mode{"realtime_display_mode", false}; - Input gain0{ "gain0" }; - Input gain1{ "gain1" }; - Input exposure0{ "exposure0" }; - Input exposure1{ "exposure1" }; + Input gain0{"gain0"}; + Input gain1{"gain1"}; + Input exposure0{"exposure0"}; + Input exposure1{"exposure1"}; - Output output0{ "output0", Halide::type_of(), D}; - Output output1{ "output1", Halide::type_of(), D}; - Output frame_count{ "frame_count", Halide::type_of(), 1 }; + Output output0{"output0", Halide::type_of(), D}; + Output output1{"output1", Halide::type_of(), D}; + Output frame_count{"frame_count", Halide::type_of(), 1}; void generate() { using namespace Halide; @@ -804,17 +773,17 @@ class U3VCamera2 : public ion::BuildingBlock> { std::vector params{ static_cast(frame_sync), static_cast(realtime_display_mode), gain0, gain1, exposure0, exposure1, - id_buf, gain_key_buf, exposure_key_buf - }; - camera2.define_extern("ion_bb_image_io_u3v_camera2", params, { Halide::type_of(), Halide::type_of() }, D); + id_buf, gain_key_buf, exposure_key_buf}; + camera2.define_extern("ion_bb_image_io_u3v_camera2", params, {Halide::type_of(), Halide::type_of()}, D); camera2.compute_root(); output0(_) = camera2(_)[0]; output1(_) = camera2(_)[1]; } - Func camera2_frame_count;{ + Func camera2_frame_count; + { Buffer id_buf = this->get_id(); - camera2_frame_count.define_extern("ion_bb_image_io_u3v_camera2_frame_count", { camera2, 2, static_cast(frame_sync), static_cast(realtime_display_mode), id_buf}, type_of(), 1); + camera2_frame_count.define_extern("ion_bb_image_io_u3v_camera2_frame_count", {camera2, 2, static_cast(frame_sync), static_cast(realtime_display_mode), id_buf}, type_of(), 1); camera2_frame_count.compute_root(); frame_count(_) = camera2_frame_count(_); } @@ -837,9 +806,9 @@ class U3VCameraN : public ion::BuildingBlock> { BuildingBlockParam gain_key_ptr{"gain_key", "Gain"}; BuildingBlockParam exposure_key_ptr{"exposure_key", "Exposure"}; - Output output{ "output", Halide::type_of(), D}; - Output device_info{ "device_info", Halide::type_of(), 1}; - Output frame_count{ "frame_count", Halide::type_of(), 1 }; + Output output{"output", Halide::type_of(), D}; + Output device_info{"device_info", Halide::type_of(), 1}; + Output frame_count{"frame_count", Halide::type_of(), 1}; std::vector *> gain; std::vector *> exposure; @@ -852,7 +821,7 @@ class U3VCameraN : public ion::BuildingBlock> { void configure() { if (enable_control) { - for (auto i=0; i("gain_" + std::to_string(i))); exposure.push_back(Halide::Internal::GeneratorBase::add_input("exposure_" + std::to_string(i))); } @@ -884,13 +853,12 @@ class U3VCameraN : public ion::BuildingBlock> { std::vector params{ id_buf, static_cast(force_sim_mode), - static_cast(width), static_cast(height),static_cast(fps), + static_cast(width), static_cast(height), static_cast(fps), static_cast(frame_sync), static_cast(realtime_display_mode), static_cast(enable_control), - gain_key_buf, exposure_key_buf, pixel_format_buf - }; + gain_key_buf, exposure_key_buf, pixel_format_buf}; - for (int i = 0; i> { output.resize(num_devices); cameraN.define_extern("ion_bb_image_io_u3v_multiple_camera" + std::to_string(num_devices), params, std::vector(num_devices, Halide::type_of()), D); cameraN.compute_root(); - if (num_devices == 1){ + if (num_devices == 1) { output[0](_) = cameraN(_); } else { - for (int i = 0; i> { static_cast(force_sim_mode), static_cast(width), static_cast(height), static_cast(fps), static_cast(frame_sync), static_cast(realtime_display_mode), - pixel_format_buf - }; + pixel_format_buf}; device_info.resize(num_devices); std::vector output_type; @@ -939,9 +906,9 @@ class U3VCameraN : public ion::BuildingBlock> { } u3v_device_info.define_extern("ion_bb_image_io_u3v_device_info" + std::to_string(device_info.size()), params, output_type, 1); u3v_device_info.compute_root(); - if (device_info.size() == 1){ + if (device_info.size() == 1) { device_info[0](_) = u3v_device_info(_); - }else{ + } else { for (int i = 0; i < device_info.size(); i++) { device_info[i](_) = u3v_device_info(_)[i]; } @@ -961,8 +928,7 @@ class U3VCameraN : public ion::BuildingBlock> { static_cast(force_sim_mode), static_cast(width), static_cast(height), static_cast(fps), static_cast(frame_sync), static_cast(realtime_display_mode), - pixel_format_buf - }; + pixel_format_buf}; frame_count.resize(num_devices); std::vector output_type; @@ -971,9 +937,9 @@ class U3VCameraN : public ion::BuildingBlock> { } cameraN_fc.define_extern("ion_bb_image_io_u3v_multiple_camera_frame_count" + std::to_string(output.size()), params, output_type, 1); cameraN_fc.compute_root(); - if (frame_count.size() == 1){ + if (frame_count.size() == 1) { frame_count[0](_) = cameraN_fc(_); - }else{ + } else { for (int i = 0; i < device_info.size(); i++) { frame_count[i](_) = cameraN_fc(_)[i]; } @@ -981,7 +947,6 @@ class U3VCameraN : public ion::BuildingBlock> { } this->register_disposer("u3v_dispose"); } - }; using U3VCameraN_U8x3 = U3VCameraN; @@ -990,7 +955,7 @@ using U3VCameraN_U16x2 = U3VCameraN; class U3VCameraGenDC : public ion::BuildingBlock { public: - BuildingBlockParam num_devices{"num_devices", 2}; // NOTE: num_devices refers to sensor count not usb device count + BuildingBlockParam num_devices{"num_devices", 2}; // NOTE: num_devices refers to sensor count not usb device count BuildingBlockParam frame_sync{"frame_sync", false}; BuildingBlockParam realtime_display_mode{"realtime_display_mode", false}; @@ -998,8 +963,8 @@ class U3VCameraGenDC : public ion::BuildingBlock { BuildingBlockParam gain_key_ptr{"gain_key", "Gain"}; BuildingBlockParam exposure_key_ptr{"exposure_key", "Exposure"}; - Output gendc{ "gendc", Halide::type_of(), 1}; - Output device_info{ "device_info", Halide::type_of(), 1}; + Output gendc{"gendc", Halide::type_of(), 1}; + Output device_info{"device_info", Halide::type_of(), 1}; std::vector *> gain; std::vector *> exposure; @@ -1012,7 +977,7 @@ class U3VCameraGenDC : public ion::BuildingBlock { void configure() { if (enable_control) { - for (auto i=0; i("gain_" + std::to_string(i))); exposure.push_back(Halide::Internal::GeneratorBase::add_input("exposure_" + std::to_string(i))); } @@ -1024,7 +989,7 @@ class U3VCameraGenDC : public ion::BuildingBlock { Func u3v_gendc("u3v_gendc"); { - Buffer id_buf = this->get_id(); + Buffer id_buf = this->get_id(); const std::string gain_key(gain_key_ptr); Buffer gain_key_buf(static_cast(gain_key.size() + 1)); @@ -1044,13 +1009,12 @@ class U3VCameraGenDC : public ion::BuildingBlock { std::vector params{ id_buf, static_cast(force_sim_mode), - static_cast(width), static_cast(height),static_cast(fps), + static_cast(width), static_cast(height), static_cast(fps), static_cast(frame_sync), static_cast(realtime_display_mode), static_cast(enable_control), - gain_key_buf, exposure_key_buf, pixel_format_buf - }; + gain_key_buf, exposure_key_buf, pixel_format_buf}; - for (int i = 0; i { } u3v_gendc.define_extern("ion_bb_image_io_u3v_gendc_camera" + std::to_string(gendc.size()), params, output_type, 1); u3v_gendc.compute_root(); - if (gendc.size() == 1){ + if (gendc.size() == 1) { gendc[0](_) = u3v_gendc(_); - }else{ + } else { for (int i = 0; i < gendc.size(); i++) { gendc[i](_) = u3v_gendc(_)[i]; } @@ -1093,8 +1057,7 @@ class U3VCameraGenDC : public ion::BuildingBlock { static_cast(force_sim_mode), static_cast(width), static_cast(height), static_cast(fps), static_cast(frame_sync), static_cast(realtime_display_mode), - pixel_format_buf - }; + pixel_format_buf}; device_info.resize(num_devices); std::vector output_type; @@ -1103,9 +1066,9 @@ class U3VCameraGenDC : public ion::BuildingBlock { } u3v_device_info.define_extern("ion_bb_image_io_u3v_device_info" + std::to_string(device_info.size()), params, output_type, 1); u3v_device_info.compute_root(); - if (device_info.size() == 1){ + if (device_info.size() == 1) { device_info[0](_) = u3v_device_info(_); - }else{ + } else { for (int i = 0; i < device_info.size(); i++) { device_info[i](_) = u3v_device_info(_)[i]; } @@ -1116,19 +1079,17 @@ class U3VCameraGenDC : public ion::BuildingBlock { } }; - - template class BinarySaver : public ion::BuildingBlock> { public: - BuildingBlockParam output_directory_ptr{ "output_directory", "." }; + BuildingBlockParam output_directory_ptr{"output_directory", "."}; BuildingBlockParam prefix_ptr{"prefix", "raw-"}; Input input_images{"input", Halide::type_of(), D}; - Input input_deviceinfo{ "input_deviceinfo", Halide::type_of(), 1 }; - Input frame_count{ "frame_count", Halide::type_of(), 1 }; - Input width{ "width" }; - Input height{ "height" }; + Input input_deviceinfo{"input_deviceinfo", Halide::type_of(), 1}; + Input frame_count{"frame_count", Halide::type_of(), 1}; + Input width{"width"}; + Input height{"height"}; Output output{"output"}; @@ -1162,7 +1123,7 @@ class BinarySaver : public ion::BuildingBlock> { deviceinfo(_) = input_deviceinfo(_); deviceinfo.compute_root(); - std::vector params = {id_buf, image, deviceinfo, fc, width, height, dim, byte_depth, output_directory_buf, prefix_buf }; + std::vector params = {id_buf, image, deviceinfo, fc, width, height, dim, byte_depth, output_directory_buf, prefix_buf}; Func ion_bb_image_io_binary_image_saver; ion_bb_image_io_binary_image_saver.define_extern("ion_bb_image_io_binary_image_saver", params, Int(32), 0); ion_bb_image_io_binary_image_saver.compute_root(); @@ -1172,24 +1133,22 @@ class BinarySaver : public ion::BuildingBlock> { } }; - using BinarySaver_U8x3 = BinarySaver; using BinarySaver_U8x2 = BinarySaver; using BinarySaver_U16x2 = BinarySaver; class BinaryGenDCSaver : public ion::BuildingBlock { public: - BuildingBlockParam output_directory_ptr{ "output_directory", "." }; + BuildingBlockParam output_directory_ptr{"output_directory", "."}; BuildingBlockParam prefix_ptr{"prefix", "raw-"}; - Input input_gendc{ "input_gendc", Halide::type_of(), 1 }; - Input input_deviceinfo{ "input_deviceinfo", Halide::type_of(), 1 }; - + Input input_gendc{"input_gendc", Halide::type_of(), 1}; + Input input_deviceinfo{"input_deviceinfo", Halide::type_of(), 1}; - Input payloadsize{ "payloadsize" }; + Input payloadsize{"payloadsize"}; - Output output{ "output" }; + Output output{"output"}; void generate() { using namespace Halide; @@ -1213,7 +1172,7 @@ class BinaryGenDCSaver : public ion::BuildingBlock { deviceinfo(_) = input_deviceinfo(_); deviceinfo.compute_root(); - std::vector params = { id_buf, gendc, deviceinfo, payloadsize, output_directory_buf, prefix_buf }; + std::vector params = {id_buf, gendc, deviceinfo, payloadsize, output_directory_buf, prefix_buf}; Func image_io_binary_gendc_saver; image_io_binary_gendc_saver.define_extern("ion_bb_image_io_binary_gendc_saver", params, Int(32), 0); image_io_binary_gendc_saver.compute_root(); @@ -1225,13 +1184,13 @@ class BinaryGenDCSaver : public ion::BuildingBlock { class BinaryLoader : public ion::BuildingBlock { public: - BuildingBlockParam output_directory_ptr{ "output_directory_ptr", "" }; - Input width{ "width", 0 }; - Input height{ "height", 0 }; - Output output0{ "output0", UInt(16), 2 }; - Output output1{ "output1", UInt(16), 2 }; - Output finished{ "finished", UInt(1), 1}; - Output bin_idx{ "bin_idx", UInt(32), 1 }; + BuildingBlockParam output_directory_ptr{"output_directory_ptr", ""}; + Input width{"width", 0}; + Input height{"height", 0}; + Output output0{"output0", UInt(16), 2}; + Output output1{"output1", UInt(16), 2}; + Output finished{"finished", UInt(1), 1}; + Output bin_idx{"bin_idx", UInt(32), 1}; void generate() { using namespace Halide; @@ -1246,18 +1205,17 @@ class BinaryLoader : public ion::BuildingBlock { output_directory_buf.fill(0); std::memcpy(output_directory_buf.data(), output_directory.c_str(), output_directory.size()); - std::vector params = { session_id_buf, width, height, output_directory_buf }; + std::vector params = {session_id_buf, width, height, output_directory_buf}; Func binaryloader; - binaryloader.define_extern("binaryloader", params, { UInt(16), UInt(16) }, 2); + binaryloader.define_extern("binaryloader", params, {UInt(16), UInt(16)}, 2); binaryloader.compute_root(); output0(_) = binaryloader(_)[0]; output1(_) = binaryloader(_)[1]; - Func binaryloader_finished; binaryloader_finished.define_extern("binaryloader_finished", - { binaryloader, session_id_buf, width, height, output_directory_buf }, - { type_of(), UInt(32)}, 1); + {binaryloader, session_id_buf, width, height, output_directory_buf}, + {type_of(), UInt(32)}, 1); binaryloader_finished.compute_root(); finished(_) = binaryloader_finished(_)[0]; bin_idx(_) = binaryloader_finished(_)[1]; @@ -1281,7 +1239,6 @@ ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::Camera2, image_io_camera2); ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::CameraN, image_io_cameraN); #endif - ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::ColorDataLoader, image_io_color_data_loader); ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::GrayscaleDataLoader, image_io_grayscale_data_loader); @@ -1310,7 +1267,7 @@ ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::BinaryLoader, image_io_binaryload ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::BinaryGenDCSaver, image_io_binary_gendc_saver); -//backward compatability +// backward compatability ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::U3VCamera1_U8x3, u3v_camera1_u8x3); ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::U3VCamera1_U16x2, u3v_camera1_u16x2); ION_REGISTER_BUILDING_BLOCK(ion::bb::image_io::U3VCamera2_U8x3, u3v_camera2_u8x3); diff --git a/src/bb/image-io/gendc_separator/ComponentHeader.h b/src/bb/image-io/gendc_separator/ComponentHeader.h index 248819c2..7944348a 100644 --- a/src/bb/image-io/gendc_separator/ComponentHeader.h +++ b/src/bb/image-io/gendc_separator/ComponentHeader.h @@ -3,14 +3,15 @@ #include "PartHeader.h" -class ComponentHeader : public Header{ +class ComponentHeader : public Header { public: - ComponentHeader(){} + ComponentHeader() { + } - ComponentHeader(char* header_info, size_t offset = 0){ + ComponentHeader(char *header_info, size_t offset = 0) { int16_t header_type; offset += Read(header_info, offset, header_type); - if (header_type != HeaderType_){ + if (header_type != HeaderType_) { std::cerr << "wrong header type in component header" << std::endl; } @@ -28,21 +29,20 @@ class ComponentHeader : public Header{ offset += sizeof(Reserved2_); offset += Read(header_info, offset, PartCount_); - for (int i = 0; i < PartCount_; ++i){ + for (int i = 0; i < PartCount_; ++i) { int64_t single_part_offset; offset += Read(header_info, offset, single_part_offset); PartOffset_.push_back(single_part_offset); } - for (int64_t & po : PartOffset_){ + for (int64_t &po : PartOffset_) { partheader_.push_back(PartHeader(header_info, po)); } - } - ComponentHeader& operator=(const ComponentHeader& src) { + ComponentHeader &operator=(const ComponentHeader &src) { partheader_ = src.partheader_; - + // HeaderType_ = 0x2000; Flags_ = src.Flags_; HeaderSize_ = src.HeaderSize_; @@ -61,45 +61,44 @@ class ComponentHeader : public Header{ return *this; } - int64_t getGroupID(){ + int64_t getGroupID() { return GroupId_; } - int64_t getTypeId(){ + int64_t getTypeId() { return TypeId_; } - int16_t getSourceId(){ + int16_t getSourceId() { return SourceId_; } - int16_t getPartCount(){ + int16_t getPartCount() { return PartCount_; } - - size_t GenerateDescriptor(char* ptr, size_t offset=0){ + size_t GenerateDescriptor(char *ptr, size_t offset = 0) { offset = GenerateHeader(ptr, offset); - for (PartHeader & ph : partheader_){ + for (PartHeader &ph : partheader_) { offset = ph.GenerateDescriptor(ptr, offset); } return offset; } - bool isComponentValid(){ + bool isComponentValid() { return Flags_ == 0; } - int32_t getFirstAvailableDataOffset(bool image){ + int32_t getFirstAvailableDataOffset(bool image) { // returns the part header index where // - component is valid // - part header type is 0x4200 (GDC_2D) if image is true int32_t jth_part = 0; - for (PartHeader &ph : partheader_){ - if (image && ph.isData2DImage()){ + for (PartHeader &ph : partheader_) { + if (image && ph.isData2DImage()) { return jth_part; - }else if (!image && !ph.isData2DImage()){ + } else if (!image && !ph.isData2DImage()) { return jth_part; } ++jth_part; @@ -107,19 +106,19 @@ class ComponentHeader : public Header{ return -1; } - int64_t getDataOffset(int32_t jth_part){ + int64_t getDataOffset(int32_t jth_part) { return partheader_.at(jth_part).getDataOffset(); } - int64_t getDataSize(int32_t jth_part){ + int64_t getDataSize(int32_t jth_part) { return partheader_.at(jth_part).getDataSize(); } - int32_t getOffsetFromTypeSpecific(int32_t jth_part, int32_t kth_typespecific, int32_t typespecific_offset = 0){ + int32_t getOffsetFromTypeSpecific(int32_t jth_part, int32_t kth_typespecific, int32_t typespecific_offset = 0) { return PartOffset_.at(jth_part) + partheader_.at(jth_part).getOffsetFromTypeSpecific(kth_typespecific, typespecific_offset); } - void DisplayHeaderInfo(){ + void DisplayHeaderInfo() { int total_size = 0; std::cout << "\nCOMPONENT HEADER" << std::endl; total_size += DisplayItemInfo("HeaderType_", HeaderType_, 2, true); @@ -139,13 +138,13 @@ class ComponentHeader : public Header{ total_size += DisplayContainer("PartOffset_", PartOffset_, 2); - for (PartHeader &ph : partheader_){ + for (PartHeader &ph : partheader_) { ph.DisplayHeaderInfo(); } } private: - size_t GenerateHeader(char* ptr, size_t offset=0){ + size_t GenerateHeader(char *ptr, size_t offset = 0) { // modify the order/items only when the structure is changed. // when you change this, don't forget to change copy constructor. size_t cpy_offset = offset; @@ -163,17 +162,17 @@ class ComponentHeader : public Header{ offset += Write(ptr, offset, Format_); offset += Write(ptr, offset, Reserved2_); offset += Write(ptr, offset, PartCount_); - + offset += WriteContainer(ptr, offset, PartOffset_); - if ((offset - cpy_offset) != HeaderSize_){ + if ((offset - cpy_offset) != HeaderSize_) { std::cerr << "Component header size is wrong" << HeaderSize_ << " != " << offset - cpy_offset << std::endl; } return offset; } std::vector partheader_; - + const int16_t HeaderType_ = 0x2000; int16_t Flags_; // int32_t HeaderSize_; @@ -191,5 +190,4 @@ class ComponentHeader : public Header{ std::vector PartOffset_; }; - #endif /*COMPONENTHEADER_H*/ \ No newline at end of file diff --git a/src/bb/image-io/gendc_separator/ContainerHeader.h b/src/bb/image-io/gendc_separator/ContainerHeader.h index f50ff8c8..73688eee 100644 --- a/src/bb/image-io/gendc_separator/ContainerHeader.h +++ b/src/bb/image-io/gendc_separator/ContainerHeader.h @@ -3,23 +3,23 @@ #include "ComponentHeader.h" -class ContainerHeader : public Header{ +class ContainerHeader : public Header { public: + ContainerHeader() { + } - ContainerHeader(){} - - ContainerHeader(char* descriptor){ + ContainerHeader(char *descriptor) { size_t offset = 0; int32_t signature; int16_t header_type; // check if the container is GenDC offset += Read(descriptor, offset, signature); - if (signature != Signature_){ + if (signature != Signature_) { std::cerr << "This ptr does NOT hace GenDC Signature" << std::endl; } - for (int i = 0 ; i < Version_.size(); i++){ + for (int i = 0; i < Version_.size(); i++) { int8_t v; offset += Read(descriptor, offset, v); Version_.at(i) = v; @@ -27,7 +27,7 @@ class ContainerHeader : public Header{ offset += sizeof(Reserved_); offset += Read(descriptor, offset, header_type); - if (header_type != HeaderType_){ + if (header_type != HeaderType_) { std::cerr << "wrong header type in container header" << std::endl; } offset += Read(descriptor, offset, Flags_); @@ -40,18 +40,18 @@ class ContainerHeader : public Header{ offset += Read(descriptor, offset, DescriptorSize_); offset += Read(descriptor, offset, ComponentCount_); - for (int i = 0; i < ComponentCount_; ++i){ + for (int i = 0; i < ComponentCount_; ++i) { int64_t single_component_offset; offset += Read(descriptor, offset, single_component_offset); ComponentOffset_.push_back(single_component_offset); } - for (int64_t & co : ComponentOffset_){ + for (int64_t &co : ComponentOffset_) { component_header_.push_back(ComponentHeader(descriptor, co)); } } - ContainerHeader& operator=(const ContainerHeader& src) { + ContainerHeader &operator=(const ContainerHeader &src) { component_header_ = src.component_header_; // Signature_ = 0x43444E47; @@ -71,15 +71,15 @@ class ContainerHeader : public Header{ return *this; } - ComponentHeader getComponentByIndex(int ith_component_index){ + ComponentHeader getComponentByIndex(int ith_component_index) { return component_header_[ith_component_index]; } - int32_t getFirstComponentIndexByTypeID(int64_t type_id){ + int32_t getFirstComponentIndexByTypeID(int64_t type_id) { int cnt = 0; - for (ComponentHeader &ch : component_header_){ - if (ch.isComponentValid()){ - if (type_id==ch.getTypeId()){ + for (ComponentHeader &ch : component_header_) { + if (ch.isComponentValid()) { + if (type_id == ch.getTypeId()) { return cnt; } } @@ -88,34 +88,33 @@ class ContainerHeader : public Header{ return -1; } - - int32_t getDescriptorSize(){ + int32_t getDescriptorSize() { return DescriptorSize_; } - size_t GenerateDescriptor(char* ptr){ + size_t GenerateDescriptor(char *ptr) { size_t offset = 0; offset = GenerateHeader(ptr); - for ( ComponentHeader &ch : component_header_){ + for (ComponentHeader &ch : component_header_) { offset = ch.GenerateDescriptor(ptr, offset); } - if ( offset != DescriptorSize_){ + if (offset != DescriptorSize_) { std::cerr << "Descriptor size is wrong" << DescriptorSize_ << " != " << offset << std::endl; } return offset; } - std::tuple getFirstAvailableDataOffset(bool image){ + std::tuple getFirstAvailableDataOffset(bool image) { // returns the component and part header index where // - component is valid // - part header type is 0x4200 (GDC_2D) if image is true int32_t ith_comp = 0; - for (ComponentHeader &ch : component_header_){ - if (ch.isComponentValid()){ + for (ComponentHeader &ch : component_header_) { + if (ch.isComponentValid()) { int32_t jth_part = ch.getFirstAvailableDataOffset(image); - if (jth_part != -1){ + if (jth_part != -1) { return std::make_tuple(ith_comp, jth_part); } ++ith_comp; @@ -124,28 +123,28 @@ class ContainerHeader : public Header{ return std::make_tuple(-1, -1); } - int64_t getDataOffset(int32_t ith_component = 0, int32_t jth_part = 0){ - if (ith_component == 0 && jth_part == 0){ + int64_t getDataOffset(int32_t ith_component = 0, int32_t jth_part = 0) { + if (ith_component == 0 && jth_part == 0) { return DataOffset_; } return component_header_.at(ith_component).getDataOffset(jth_part); } - int64_t getDataSize(){ + int64_t getDataSize() { return DataSize_; } - int64_t getDataSize(int32_t ith_component = 0, int32_t jth_part = 0){ + int64_t getDataSize(int32_t ith_component = 0, int32_t jth_part = 0) { return component_header_.at(ith_component).getDataSize(jth_part); } int32_t getOffsetFromTypeSpecific(int32_t ith_component, int32_t jth_part, - int32_t kth_typespecific, int32_t typespecific_offset = 0){ + int32_t kth_typespecific, int32_t typespecific_offset = 0) { return component_header_.at(ith_component).getOffsetFromTypeSpecific(jth_part, kth_typespecific, typespecific_offset); } - void DisplayHeaderInfo(){ + void DisplayHeaderInfo() { int total_size = 0; std::cout << "\nCONTAINER HEADER" << std::endl; total_size += DisplayItemInfo("Signature_", Signature_, 1, true); @@ -163,13 +162,13 @@ class ContainerHeader : public Header{ total_size += DisplayContainer("ComponentOffset_", ComponentOffset_, 1); - for (ComponentHeader &ch : component_header_){ + for (ComponentHeader &ch : component_header_) { ch.DisplayHeaderInfo(); } } private: - size_t GenerateHeader(char* ptr){ + size_t GenerateHeader(char *ptr) { // modify the order/items only when the structure is changed. // when you change this, don't forget to change copy constructor. size_t offset = 0; @@ -187,7 +186,7 @@ class ContainerHeader : public Header{ offset += Write(ptr, offset, ComponentCount_); offset += WriteContainer(ptr, offset, ComponentOffset_); - if ( offset != HeaderSize_){ + if (offset != HeaderSize_) { std::cerr << "Container header size is wrong" << HeaderSize_ << " != " << offset << std::endl; } return offset; @@ -204,7 +203,7 @@ class ContainerHeader : public Header{ int16_t Flags_; // int32_t HeaderSize_; int64_t Id_; - int64_t VariableFields_; //including 6 Byte-wide Reserved + int64_t VariableFields_; // including 6 Byte-wide Reserved int64_t DataSize_; int64_t DataOffset_; int32_t DescriptorSize_; diff --git a/src/bb/image-io/gendc_separator/Descriptor.h b/src/bb/image-io/gendc_separator/Descriptor.h index ebc9e10d..c1a96909 100644 --- a/src/bb/image-io/gendc_separator/Descriptor.h +++ b/src/bb/image-io/gendc_separator/Descriptor.h @@ -37,7 +37,7 @@ #define GDC_1D 0x4100 #define GDC_2D 0x4200 -//format +// format #define Mono12 0x01100005 #define Data8 0x01080116 #define Data16 0x01100118 @@ -59,125 +59,122 @@ #define VERSION_OFFSET 4 namespace { - enum offset { - descriptor_size, - deta_size, - data_offset, - }; +enum offset { + descriptor_size, + deta_size, + data_offset, +}; } #define GENDC_V10 0x0100 // https://www.emva.org/wp-content/uploads/GenICam_GenDC_v1_1.pdf -std::map> offset_for_version = -{ - {GENDC_V10, std::array{48, 32, 40}}, +std::map> offset_for_version = + { + {GENDC_V10, std::array{48, 32, 40}}, }; - -enum display_lebel{ +enum display_lebel { default_display, container_header_display, component_header_display, part_header_display }; -std::string display_indent(int level=default_display){ - std::string ret=""; - for (int i = 0; i < level; ++i){ +std::string display_indent(int level = default_display) { + std::string ret = ""; + for (int i = 0; i < level; ++i) { ret += "\t"; } return ret; -} +} -class Header{ +class Header { public: - size_t getHeaderSize(){ + size_t getHeaderSize() { return HeaderSize_; } protected: - template - void DisplayItem(T item, bool hex_format){ - if(sizeof(item) == sizeof(char)){ + template + void DisplayItem(T item, bool hex_format) { + if (sizeof(item) == sizeof(char)) { DisplayItem(static_cast(item), hex_format); - }else{ + } else { std::cout << std::right << std::setw(DISPLAY_VALUE_WIDTH); - if (hex_format){ + if (hex_format) { std::cout << std::hex << "0x" << item << std::endl; - }else{ + } else { std::cout << std::dec << item << std::endl; } } - } - template - int DisplayItemInfo(std::string item_name, T item, int level=default_display, bool hex_format=false){ + template + int DisplayItemInfo(std::string item_name, T item, int level = default_display, bool hex_format = false) { std::string indent = display_indent(level); int sizeof_item = sizeof(item); std::cout << indent << std::right << std::setw(DISPLAY_ITEM_WIDTH) << item_name; - std::cout << std::right << std::setw(DISPLAY_SIZE_WIDTH) << " (" << sizeof_item << "):"; + std::cout << std::right << std::setw(DISPLAY_SIZE_WIDTH) << " (" << sizeof_item << "):"; DisplayItem(item, hex_format); return sizeof_item; } - template - int DisplayContainer(std::string container_name, const std::vector&container, int level=default_display, bool hex=false){ + template + int DisplayContainer(std::string container_name, const std::vector &container, int level = default_display, bool hex = false) { int total_size = 0; - if (container.size() > 0){ + if (container.size() > 0) { std::string key = container_name; - for(int i=0; i < container.size(); ++i){ + for (int i = 0; i < container.size(); ++i) { total_size += DisplayItemInfo(i > 0 ? "" : key, container.at(i), level, hex); } - }else{ + } else { std::cout << display_indent(level) << std::right << std::setw(DISPLAY_ITEM_WIDTH) << container_name; - std::cout << std::right << std::setw(DISPLAY_SIZE_WIDTH) << " (" << 0 << "):\n"; + std::cout << std::right << std::setw(DISPLAY_SIZE_WIDTH) << " (" << 0 << "):\n"; } return total_size; } - template - int DisplayContainer(std::string container_name, const std::array&container, int level=default_display, bool hex=false){ + template + int DisplayContainer(std::string container_name, const std::array &container, int level = default_display, bool hex = false) { int total_size = 0; - if (container.size() > 0){ + if (container.size() > 0) { std::string key = container_name; - for(int i=0; i < container.size(); ++i){ + for (int i = 0; i < container.size(); ++i) { total_size += DisplayItemInfo(i > 0 ? "" : key, container.at(i), level, hex); } - }else{ + } else { std::cout << display_indent(level) << std::right << std::setw(DISPLAY_ITEM_WIDTH) << container_name; - std::cout << std::right << std::setw(DISPLAY_SIZE_WIDTH) << " (" << 0 << "):\n"; + std::cout << std::right << std::setw(DISPLAY_SIZE_WIDTH) << " (" << 0 << "):\n"; } return total_size; } - - template - size_t Read(char* ptr, size_t offset, T& item){ - memcpy(&item, ptr+static_cast(offset), sizeof(item)); + template + size_t Read(char *ptr, size_t offset, T &item) { + memcpy(&item, ptr + static_cast(offset), sizeof(item)); return sizeof(item); } - template - size_t Write(char* ptr, size_t offset, T item){ - memcpy(ptr+static_cast(offset), &item, sizeof(item)); + template + size_t Write(char *ptr, size_t offset, T item) { + memcpy(ptr + static_cast(offset), &item, sizeof(item)); return sizeof(item); } - template - size_t WriteContainer(char* ptr, size_t offset, std::vector&container){ + template + size_t WriteContainer(char *ptr, size_t offset, std::vector &container) { size_t container_offset = 0; - for (T& item : container){ + for (T &item : container) { container_offset += Write(ptr, offset + container_offset, item); } return container_offset; } - template - size_t WriteContainer(char* ptr, size_t offset, std::array&container){ + template + size_t WriteContainer(char *ptr, size_t offset, std::array &container) { size_t container_offset = 0; - for (T& item : container){ + for (T &item : container) { container_offset += Write(ptr, offset + container_offset, item); } return container_offset; diff --git a/src/bb/image-io/gendc_separator/PartHeader.h b/src/bb/image-io/gendc_separator/PartHeader.h index aa0cab37..69b586c6 100644 --- a/src/bb/image-io/gendc_separator/PartHeader.h +++ b/src/bb/image-io/gendc_separator/PartHeader.h @@ -9,29 +9,30 @@ #include "Descriptor.h" -int getByteInFormat(int format){ - switch (format){ - case Mono12: - return 2; - case Data8: - return 1; - case Data16: - return 2; - case Data32: - return 4; - case Data32f: - return 4; - default: - throw std::invalid_argument("wrong format\n"); +int getByteInFormat(int format) { + switch (format) { + case Mono12: + return 2; + case Data8: + return 1; + case Data16: + return 2; + case Data32: + return 4; + case Data32f: + return 4; + default: + throw std::invalid_argument("wrong format\n"); } } // namespace { -class PartHeader : public Header{ -public: - PartHeader(){} +class PartHeader : public Header { +public: + PartHeader() { + } // constructor with existing header info - PartHeader(char* header_info, size_t offset = 0){ + PartHeader(char *header_info, size_t offset = 0) { size_t total_size = 0; offset += Read(header_info, offset, HeaderType_); @@ -47,63 +48,63 @@ class PartHeader : public Header{ // get number of typespecific fields from HeaderSize_ int num_typespecific = getNumTypeSpecific(HeaderSize_); - if (num_typespecific > 0){ + if (num_typespecific > 0) { offset += Read(header_info, offset, Dimension_[0]); offset += Read(header_info, offset, Dimension_[1]); } - if (num_typespecific > 1){ + if (num_typespecific > 1) { offset += Read(header_info, offset, Padding_[0]); offset += Read(header_info, offset, Padding_[1]); } - if (num_typespecific > 2){ + if (num_typespecific > 2) { offset += sizeof(InfoReserved_); int64_t typespecific_item; - for (int i = 0; i < num_typespecific - 2; ++i){ + for (int i = 0; i < num_typespecific - 2; ++i) { offset += Read(header_info, offset, typespecific_item); TypeSpecific_.push_back(typespecific_item); - } + } } } - PartHeader& operator=(const PartHeader& src) { + PartHeader &operator=(const PartHeader &src) { HeaderType_ = src.HeaderType_; - Flags_= src.Flags_; - HeaderSize_= src.HeaderSize_; - Format_= src.Format_; + Flags_ = src.Flags_; + HeaderSize_ = src.HeaderSize_; + Format_ = src.Format_; // Reserved_ = 0; - FlowId_= src.FlowId_; - FlowOffset_= src.FlowOffset_; - DataSize_= src.DataSize_; - DataOffset_= src.DataOffset_; - - Dimension_= src.Dimension_; - Padding_= src.Padding_; - TypeSpecific_= src.TypeSpecific_; + FlowId_ = src.FlowId_; + FlowOffset_ = src.FlowOffset_; + DataSize_ = src.DataSize_; + DataOffset_ = src.DataOffset_; + + Dimension_ = src.Dimension_; + Padding_ = src.Padding_; + TypeSpecific_ = src.TypeSpecific_; return *this; } - size_t GenerateDescriptor(char* ptr, size_t offset=0){ + size_t GenerateDescriptor(char *ptr, size_t offset = 0) { offset = GenerateHeader(ptr, offset); return offset; } - bool isData2DImage(){ + bool isData2DImage() { return HeaderType_ == 0x4200; } - int64_t getDataOffset(){ + int64_t getDataOffset() { return DataOffset_; } - int64_t getDataSize(){ + int64_t getDataSize() { return DataSize_; } - int32_t getOffsetFromTypeSpecific(int32_t kth_typespecific, int32_t typespecific_offset = 0){ + int32_t getOffsetFromTypeSpecific(int32_t kth_typespecific, int32_t typespecific_offset = 0) { return offset_for_version[GENDC_V10].at(2) + 8 * (kth_typespecific - 1) + typespecific_offset; } - void DisplayHeaderInfo(){ + void DisplayHeaderInfo() { int total_size = 0; std::cout << "\nPART HEADER" << std::endl; total_size += DisplayItemInfo("HeaderType_", HeaderType_, 3, true); @@ -120,17 +121,17 @@ class PartHeader : public Header{ total_size += DisplayContainer("Padding_", Padding_, 3); total_size += DisplayItemInfo("InfoReserved_", InfoReserved_, 3); - total_size += DisplayContainer("TypeSpecific_", TypeSpecific_, 3); + total_size += DisplayContainer("TypeSpecific_", TypeSpecific_, 3); } private: // you need parameters to create the object - int getNumTypeSpecific(size_t header_size){ - return static_cast(( header_size - 40 ) / 8); + int getNumTypeSpecific(size_t header_size) { + return static_cast((header_size - 40) / 8); } - size_t GenerateHeader(char* ptr, size_t offset = 0){ + size_t GenerateHeader(char *ptr, size_t offset = 0) { // modify the order/items only when the structure is changed. // when you change this, don't forget to change copy constructor. size_t cpy_offset = offset; @@ -148,7 +149,7 @@ class PartHeader : public Header{ offset += Write(ptr, offset, InfoReserved_); offset += WriteContainer(ptr, offset, TypeSpecific_); - if ((offset - cpy_offset) != HeaderSize_){ + if ((offset - cpy_offset) != HeaderSize_) { std::cerr << "Part header size is wrong" << HeaderSize_ << " != " << offset - cpy_offset << std::endl; } return offset; @@ -156,7 +157,7 @@ class PartHeader : public Header{ int16_t HeaderType_; int16_t Flags_; - // int32_t HeaderSize_; + // int32_t HeaderSize_; int32_t Format_; const int16_t Reserved_ = 0; int16_t FlowId_; diff --git a/src/bb/image-io/gendc_separator/tools.h b/src/bb/image-io/gendc_separator/tools.h index bddbebd0..7ec8a4fd 100644 --- a/src/bb/image-io/gendc_separator/tools.h +++ b/src/bb/image-io/gendc_separator/tools.h @@ -11,43 +11,43 @@ // this contains signature and version info // ***************************************************************************** -bool isGenDC(char* buf){ +bool isGenDC(char *buf) { int32_t signature; std::memcpy(&signature, buf + SIGNATURE_OFFSET, sizeof(int32_t)); - if (signature != GENDC_SIGNATURE){ + if (signature != GENDC_SIGNATURE) { std::cout << "[LOG ion-kit(gendc-separator)] The data is not genDC format" << std::endl; return false; } return true; } -std::array getGenDCVersion(char* buf){ +std::array getGenDCVersion(char *buf) { std::array version; - for (int i = 0; i < version.size(); ++i){ - std::memcpy(&version.at(i), buf + VERSION_OFFSET + sizeof(int8_t)*i, sizeof(int8_t)); + for (int i = 0; i < version.size(); ++i) { + std::memcpy(&version.at(i), buf + VERSION_OFFSET + sizeof(int8_t) * i, sizeof(int8_t)); } return version; } -int32_t getDescriptorSize(char* buf, const int container_version, std::array& v){ +int32_t getDescriptorSize(char *buf, const int container_version, std::array &v) { int8_t hex_offset = 0x00; int32_t descriptor_size; - try{ + try { std::memcpy(&descriptor_size, buf + (offset_for_version.at(container_version)).at(::descriptor_size), sizeof(int32_t)); - }catch (std::out_of_range& e){ + } catch (std::out_of_range &e) { std::stringstream ss; ss << "ERROR\t" << e.what() << ": " - << "The version of container " - << v.at(0) - hex_offset << "." - << v.at(1) - hex_offset << "." - << v.at(2) - hex_offset << " is not supported."; + << "The version of container " + << v.at(0) - hex_offset << "." + << v.at(1) - hex_offset << "." + << v.at(2) - hex_offset << " is not supported."; const std::string error_message = ss.str(); throw std::out_of_range(error_message); - }catch(std::exception& e){ + } catch (std::exception &e) { throw e; - } + } return descriptor_size; } diff --git a/src/bb/image-io/httplib.h b/src/bb/image-io/httplib.h index d1d8486b..363afe7a 100644 --- a/src/bb/image-io/httplib.h +++ b/src/bb/image-io/httplib.h @@ -81,10 +81,8 @@ #endif #ifndef CPPHTTPLIB_THREAD_POOL_COUNT -#define CPPHTTPLIB_THREAD_POOL_COUNT \ - ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ - ? std::thread::hardware_concurrency() - 1 \ - : 0)) +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() - 1 : 0)) #endif /* @@ -94,11 +92,11 @@ #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS -#endif //_CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS #ifndef _CRT_NONSTDC_NO_DEPRECATE #define _CRT_NONSTDC_NO_DEPRECATE -#endif //_CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE #if defined(_MSC_VER) #ifdef _WIN64 @@ -110,19 +108,19 @@ using ssize_t = int; #if _MSC_VER < 1900 #define snprintf _snprintf_s #endif -#endif // _MSC_VER +#endif // _MSC_VER #ifndef S_ISREG #define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) -#endif // S_ISREG +#endif // S_ISREG #ifndef S_ISDIR #define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) -#endif // S_ISDIR +#endif // S_ISDIR #ifndef NOMINMAX #define NOMINMAX -#endif // NOMINMAX +#endif // NOMINMAX #include #include @@ -142,14 +140,14 @@ using ssize_t = int; #ifndef strcasecmp #define strcasecmp _stricmp -#endif // strcasecmp +#endif // strcasecmp using socket_t = SOCKET; #ifdef CPPHTTPLIB_USE_POLL #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif -#else // not _WIN32 +#else // not _WIN32 #include #include @@ -171,7 +169,7 @@ using socket_t = SOCKET; using socket_t = int; #define INVALID_SOCKET (-1) -#endif //_WIN32 +#endif //_WIN32 #include #include @@ -216,7 +214,7 @@ using socket_t = int; #if OPENSSL_VERSION_NUMBER < 0x10100000L #include inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { - return M_ASN1_STRING_data(asn1); + return M_ASN1_STRING_data(asn1); } #endif #endif @@ -238,14 +236,14 @@ namespace httplib { namespace detail { struct ci { - bool operator()(const std::string &s1, const std::string &s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), s2.begin(), s2.end(), - [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); - } + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } }; -} // namespace detail +} // namespace detail using Headers = std::multimap; @@ -258,44 +256,48 @@ struct Response; using ResponseHandler = std::function; struct MultipartFormData { - std::string name; - std::string content; - std::string filename; - std::string content_type; + std::string name; + std::string content; + std::string filename; + std::string content_type; }; using MultipartFormDataItems = std::vector; using MultipartFormDataMap = std::multimap; class DataSink { public: - DataSink() : os(&sb_), sb_(*this) {} + DataSink() + : os(&sb_), sb_(*this) { + } - DataSink(const DataSink &) = delete; - DataSink &operator=(const DataSink &) = delete; - DataSink(DataSink &&) = delete; - DataSink &operator=(DataSink &&) = delete; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; - std::function write; - std::function done; - std::function is_writable; - std::ostream os; + std::function write; + std::function done; + std::function is_writable; + std::ostream os; private: - class data_sink_streambuf : public std::streambuf { - public: - explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + class data_sink_streambuf : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) + : sink_(sink) { + } - protected: - std::streamsize xsputn(const char *s, std::streamsize n) { - sink_.write(s, static_cast(n)); - return n; - } + protected: + std::streamsize xsputn(const char *s, std::streamsize n) { + sink_.write(s, static_cast(n)); + return n; + } - private: - DataSink &sink_; - }; + private: + DataSink &sink_; + }; - data_sink_streambuf sb_; + data_sink_streambuf sb_; }; using ContentProvider = @@ -312,223 +314,231 @@ using MultipartContentHeader = class ContentReader { public: - using Reader = std::function; - using MultipartReader = std::function; + using Reader = std::function; + using MultipartReader = std::function; - ContentReader(Reader reader, MultipartReader multipart_reader) - : reader_(reader), multipart_reader_(multipart_reader) {} + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(reader), multipart_reader_(multipart_reader) { + } - bool operator()(MultipartContentHeader header, - ContentReceiver receiver) const { - return multipart_reader_(header, receiver); - } + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return multipart_reader_(header, receiver); + } - bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + bool operator()(ContentReceiver receiver) const { + return reader_(receiver); + } - Reader reader_; - MultipartReader multipart_reader_; + Reader reader_; + MultipartReader multipart_reader_; }; using Range = std::pair; using Ranges = std::vector; struct Request { - std::string method; - std::string path; - Headers headers; - std::string body; - - std::string remote_addr; - int remote_port = -1; - - // for server - std::string version; - std::string target; - Params params; - MultipartFormDataMap files; - Ranges ranges; - Match matches; - - // for client - size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; - ResponseHandler response_handler; - ContentReceiver content_receiver; - size_t content_length = 0; - ContentProvider content_provider; - Progress progress; + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + size_t content_length = 0; + ContentProvider content_provider; + Progress progress; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl; + const SSL *ssl; #endif - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - template - T get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - bool has_param(const char *key) const; - std::string get_param_value(const char *key, size_t id = 0) const; - size_t get_param_value_count(const char *key) const; + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; - bool is_multipart_form_data() const; + bool is_multipart_form_data() const; - bool has_file(const char *key) const; - MultipartFormData get_file_value(const char *key) const; + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; - // private members... - size_t authorization_count_ = 0; + // private members... + size_t authorization_count_ = 0; }; struct Response { - std::string version; - int status = -1; - std::string reason; - Headers headers; - std::string body; - - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - template - T get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); - - void set_redirect(const char *url, int status = 302); - void set_redirect(const std::string &url, int status = 302); - void set_content(const char *s, size_t n, const char *content_type); - void set_content(std::string s, const char *content_type); - - void set_content_provider( - size_t length, const char *content_type, ContentProvider provider, - const std::function &resource_releaser = nullptr); - - void set_content_provider( - const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); - - void set_chunked_content_provider( - const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); - - Response() = default; - Response(const Response &) = default; - Response &operator=(const Response &) = default; - Response(Response &&) = default; - Response &operator=(Response &&) = default; - ~Response() { - if (content_provider_resource_releaser_) { - content_provider_resource_releaser_(); - } - } - - // private members... - size_t content_length_ = 0; - ContentProvider content_provider_; - std::function content_provider_resource_releaser_; - bool is_chunked_content_provider = false; + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url, int status = 302); + void set_redirect(const std::string &url, int status = 302); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(std::string s, const char *content_type); + + void set_content_provider( + size_t length, const char *content_type, ContentProvider provider, + const std::function &resource_releaser = nullptr); + + void set_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + const std::function &resource_releaser = nullptr); + + void set_chunked_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + const std::function &resource_releaser = nullptr); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + std::function content_provider_resource_releaser_; + bool is_chunked_content_provider = false; }; class Stream { public: - virtual ~Stream() = default; + virtual ~Stream() = default; - virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; - virtual ssize_t read(char *ptr, size_t size) = 0; - virtual ssize_t write(const char *ptr, size_t size) = 0; - virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; - template - ssize_t write_format(const char *fmt, const Args &... args); - ssize_t write(const char *ptr); - ssize_t write(const std::string &s); + template + ssize_t write_format(const char *fmt, const Args &...args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); }; class TaskQueue { public: - TaskQueue() = default; - virtual ~TaskQueue() = default; + TaskQueue() = default; + virtual ~TaskQueue() = default; - virtual void enqueue(std::function fn) = 0; - virtual void shutdown() = 0; + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; - virtual void on_idle(){}; + virtual void on_idle(){}; }; class ThreadPool : public TaskQueue { public: - explicit ThreadPool(size_t n) : shutdown_(false) { - while (n) { - threads_.emplace_back(worker(*this)); - n--; + explicit ThreadPool(size_t n) + : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } } - } - - ThreadPool(const ThreadPool &) = delete; - ~ThreadPool() override = default; - void enqueue(std::function fn) override { - std::unique_lock lock(mutex_); - jobs_.push_back(fn); - cond_.notify_one(); - } + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; - void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); } - cond_.notify_all(); + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); - // Join... - for (auto &t : threads_) { - t.join(); + // Join... + for (auto &t : threads_) { + t.join(); + } } - } private: - struct worker { - explicit worker(ThreadPool &pool) : pool_(pool) {} + struct worker { + explicit worker(ThreadPool &pool) + : pool_(pool) { + } - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + if (pool_.shutdown_ && pool_.jobs_.empty()) { + break; + } - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } - assert(true == static_cast(fn)); - fn(); - } - } + assert(true == static_cast(fn)); + fn(); + } + } - ThreadPool &pool_; - }; - friend struct worker; + ThreadPool &pool_; + }; + friend struct worker; - std::vector threads_; - std::list> jobs_; + std::vector threads_; + std::list> jobs_; - bool shutdown_; + bool shutdown_; - std::condition_variable cond_; - std::mutex mutex_; + std::condition_variable cond_; + std::mutex mutex_; }; using Logger = std::function; @@ -536,670 +546,687 @@ using Logger = std::function; using SocketOptions = std::function; inline void default_socket_options(socket_t sock) { - int yes = 1; + int yes = 1; #ifdef _WIN32 - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); - setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast(&yes), sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&yes), sizeof(yes)); #else #ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), - sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); #else - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); #endif #endif } class Server { public: - using Handler = std::function; - using HandlerWithContentReader = std::function; - using Expect100ContinueHandler = - std::function; + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Expect100ContinueHandler = + std::function; - Server(); + Server(); - virtual ~Server(); + virtual ~Server(); - virtual bool is_valid() const; + virtual bool is_valid() const; - Server &Get(const char *pattern, Handler handler); - Server &Post(const char *pattern, Handler handler); - Server &Post(const char *pattern, HandlerWithContentReader handler); - Server &Put(const char *pattern, Handler handler); - Server &Put(const char *pattern, HandlerWithContentReader handler); - Server &Patch(const char *pattern, Handler handler); - Server &Patch(const char *pattern, HandlerWithContentReader handler); - Server &Delete(const char *pattern, Handler handler); - Server &Delete(const char *pattern, HandlerWithContentReader handler); - Server &Options(const char *pattern, Handler handler); + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Delete(const char *pattern, HandlerWithContentReader handler); + Server &Options(const char *pattern, Handler handler); - bool set_base_dir(const char *dir, const char *mount_point = nullptr); - bool set_mount_point(const char *mount_point, const char *dir); - bool remove_mount_point(const char *mount_point); - void set_file_extension_and_mimetype_mapping(const char *ext, - const char *mime); - void set_file_request_handler(Handler handler); + bool set_base_dir(const char *dir, const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); - void set_error_handler(Handler handler); - void set_expect_100_continue_handler(Expect100ContinueHandler handler); - void set_logger(Logger logger); + void set_error_handler(Handler handler); + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + void set_logger(Logger logger); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_keep_alive_max_count(size_t count); - void set_keep_alive_timeout(time_t sec); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); - void set_idle_interval(time_t sec, time_t usec = 0); + void set_keep_alive_max_count(size_t count); + void set_keep_alive_timeout(time_t sec); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); + void set_idle_interval(time_t sec, time_t usec = 0); - void set_payload_max_length(size_t length); + void set_payload_max_length(size_t length); - bool bind_to_port(const char *host, int port, int socket_flags = 0); - int bind_to_any_port(const char *host, int socket_flags = 0); - bool listen_after_bind(); + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - bool listen(const char *host, int port, int socket_flags = 0); + bool listen(const char *host, int port, int socket_flags = 0); - bool is_running() const; - void stop(); + bool is_running() const; + void stop(); - std::function new_task_queue; + std::function new_task_queue; protected: - bool process_request(Stream &strm, bool close_connection, - bool &connection_closed, - const std::function &setup_request); - - std::atomic svr_sock_; - size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; - time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; - time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; - size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + bool process_request(Stream &strm, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; private: - using Handlers = std::vector>; - using HandlersForContentReader = - std::vector>; - - socket_t create_server_socket(const char *host, int port, int socket_flags, - SocketOptions socket_options) const; - int bind_internal(const char *host, int port, int socket_flags); - bool listen_internal(); - - bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(Request &req, Response &res, bool head = false); - bool dispatch_request(Request &req, Response &res, const Handlers &handlers); - bool - dispatch_request_for_content_reader(Request &req, Response &res, - ContentReader content_reader, - const HandlersForContentReader &handlers); - - bool parse_request_line(const char *s, Request &req); - bool write_response(Stream &strm, bool close_connection, const Request &req, - Response &res); - bool write_content_with_provider(Stream &strm, const Request &req, - Response &res, const std::string &boundary, - const std::string &content_type); - bool read_content(Stream &strm, Request &req, Response &res); - bool - read_content_with_content_receiver(Stream &strm, Request &req, Response &res, - ContentReceiver receiver, - MultipartContentHeader multipart_header, - ContentReceiver multipart_receiver); - bool read_content_core(Stream &strm, Request &req, Response &res, - ContentReceiver receiver, - MultipartContentHeader mulitpart_header, - ContentReceiver multipart_receiver); - - virtual bool process_and_close_socket(socket_t sock); - - std::atomic is_running_; - std::vector> base_dirs_; - std::map file_extension_and_mimetype_map_; - Handler file_request_handler_; - Handlers get_handlers_; - Handlers post_handlers_; - HandlersForContentReader post_handlers_for_content_reader_; - Handlers put_handlers_; - HandlersForContentReader put_handlers_for_content_reader_; - Handlers patch_handlers_; - HandlersForContentReader patch_handlers_for_content_reader_; - Handlers delete_handlers_; - HandlersForContentReader delete_handlers_for_content_reader_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; - Expect100ContinueHandler expect_100_continue_handler_; - - bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; - SocketOptions socket_options_ = default_socket_options; + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; + + socket_t create_server_socket(const char *host, int port, int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, const Handlers &handlers); + bool + dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + const HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool close_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; + + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = default_socket_options; }; enum Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification }; class Result { public: - Result(const std::shared_ptr &res, Error err) - : res_(res), err_(err) {} - operator bool() const { return res_ != nullptr; } - bool operator==(std::nullptr_t) const { return res_ == nullptr; } - bool operator!=(std::nullptr_t) const { return res_ != nullptr; } - const Response &value() const { return *res_; } - const Response &operator*() const { return *res_; } - const Response *operator->() const { return res_.get(); } - Error error() const { return err_; } + Result(const std::shared_ptr &res, Error err) + : res_(res), err_(err) { + } + operator bool() const { + return res_ != nullptr; + } + bool operator==(std::nullptr_t) const { + return res_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return res_ != nullptr; + } + const Response &value() const { + return *res_; + } + const Response &operator*() const { + return *res_; + } + const Response *operator->() const { + return res_.get(); + } + Error error() const { + return err_; + } private: - std::shared_ptr res_; - Error err_; + std::shared_ptr res_; + Error err_; }; class ClientImpl { public: - explicit ClientImpl(const std::string &host); - - explicit ClientImpl(const std::string &host, int port); - - explicit ClientImpl(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); - - virtual ~ClientImpl(); - - virtual bool is_valid() const; - - Result Get(const char *path); - Result Get(const char *path, const Headers &headers); - Result Get(const char *path, Progress progress); - Result Get(const char *path, const Headers &headers, Progress progress); - Result Get(const char *path, ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); - Result Get(const char *path, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, ContentReceiver content_receiver, - Progress progress); - - Result Head(const char *path); - Result Head(const char *path, const Headers &headers); - - Result Post(const char *path); - Result Post(const char *path, const std::string &body, - const char *content_type); - Result Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Post(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Params ¶ms); - Result Post(const char *path, const Headers &headers, const Params ¶ms); - Result Post(const char *path, const MultipartFormDataItems &items); - Result Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - - Result Put(const char *path); - Result Put(const char *path, const std::string &body, - const char *content_type); - Result Put(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Put(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Params ¶ms); - Result Put(const char *path, const Headers &headers, const Params ¶ms); - - Result Patch(const char *path, const std::string &body, + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const char *path); + Result Get(const char *path, const Headers &headers); + Result Get(const char *path, Progress progress); + Result Get(const char *path, const Headers &headers, Progress progress); + Result Get(const char *path, ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const char *path, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Head(const char *path); + Result Head(const char *path, const Headers &headers); + + Result Post(const char *path); + Result Post(const char *path, const std::string &body, + const char *content_type); + Result Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Post(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Params ¶ms); + Result Post(const char *path, const Headers &headers, const Params ¶ms); + Result Post(const char *path, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + Result Put(const char *path); + Result Put(const char *path, const std::string &body, + const char *content_type); + Result Put(const char *path, const Headers &headers, const std::string &body, const char *content_type); - Result Patch(const char *path, const Headers &headers, - const std::string &body, const char *content_type); - Result Patch(const char *path, size_t content_length, + Result Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type); - Result Patch(const char *path, const Headers &headers, size_t content_length, + Result Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type); + Result Put(const char *path, const Params ¶ms); + Result Put(const char *path, const Headers &headers, const Params ¶ms); - Result Delete(const char *path); - Result Delete(const char *path, const std::string &body, - const char *content_type); - Result Delete(const char *path, const Headers &headers); - Result Delete(const char *path, const Headers &headers, - const std::string &body, const char *content_type); + Result Patch(const char *path, const std::string &body, + const char *content_type); + Result Patch(const char *path, const Headers &headers, + const std::string &body, const char *content_type); + Result Patch(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + + Result Delete(const char *path); + Result Delete(const char *path, const std::string &body, + const char *content_type); + Result Delete(const char *path, const Headers &headers); + Result Delete(const char *path, const Headers &headers, + const std::string &body, const char *content_type); - Result Options(const char *path); - Result Options(const char *path, const Headers &headers); + Result Options(const char *path); + Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res); + bool send(const Request &req, Response &res); - size_t is_socket_open() const; + size_t is_socket_open() const; - void stop(); + void stop(); - void set_default_headers(Headers headers); + void set_default_headers(Headers headers); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_connection_timeout(time_t sec, time_t usec = 0); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); - void set_basic_auth(const char *username, const char *password); - void set_bearer_token_auth(const char *token); + void set_basic_auth(const char *username, const char *password); + void set_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const char *username, const char *password); + void set_digest_auth(const char *username, const char *password); #endif - void set_keep_alive(bool on); - void set_follow_location(bool on); + void set_keep_alive(bool on); + void set_follow_location(bool on); - void set_compress(bool on); + void set_compress(bool on); - void set_decompress(bool on); + void set_decompress(bool on); - void set_interface(const char *intf); + void set_interface(const char *intf); - void set_proxy(const char *host, int port); - void set_proxy_basic_auth(const char *username, const char *password); - void set_proxy_bearer_token_auth(const char *token); + void set_proxy(const char *host, int port); + void set_proxy_basic_auth(const char *username, const char *password); + void set_proxy_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const char *username, const char *password); + void set_proxy_digest_auth(const char *username, const char *password); #endif - void set_logger(Logger logger); + void set_logger(Logger logger); protected: - struct Socket { - socket_t sock = INVALID_SOCKET; + struct Socket { + socket_t sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; + SSL *ssl = nullptr; #endif - bool is_open() const { return sock != INVALID_SOCKET; } - }; + bool is_open() const { + return sock != INVALID_SOCKET; + } + }; - virtual bool create_and_connect_socket(Socket &socket); - virtual void close_socket(Socket &socket, bool process_socket_ret); + virtual bool create_and_connect_socket(Socket &socket); + virtual void close_socket(Socket &socket, bool process_socket_ret); - bool process_request(Stream &strm, const Request &req, Response &res, - bool close_connection); + bool process_request(Stream &strm, const Request &req, Response &res, + bool close_connection); - Error get_last_error() const; + Error get_last_error() const; - // Error state - mutable Error error_ = Error::Success; + // Error state + mutable Error error_ = Error::Success; - // Socket endoint information - const std::string host_; - const int port_; - const std::string host_and_port_; + // Socket endoint information + const std::string host_; + const int port_; + const std::string host_and_port_; - // Current open socket - Socket socket_; - mutable std::mutex socket_mutex_; - std::recursive_mutex request_mutex_; + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; - // Default headers - Headers default_headers_; + // Default headers + Headers default_headers_; - // Settings - std::string client_cert_path_; - std::string client_key_path_; + // Settings + std::string client_cert_path_; + std::string client_key_path_; - time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; - time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - std::string basic_auth_username_; - std::string basic_auth_password_; - std::string bearer_token_auth_token_; + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string digest_auth_username_; - std::string digest_auth_password_; + std::string digest_auth_username_; + std::string digest_auth_password_; #endif - bool keep_alive_ = false; - bool follow_location_ = false; + bool keep_alive_ = false; + bool follow_location_ = false; - bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; - SocketOptions socket_options_ = nullptr; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = nullptr; - bool compress_ = false; - bool decompress_ = true; + bool compress_ = false; + bool decompress_ = true; - std::string interface_; + std::string interface_; - std::string proxy_host_; - int proxy_port_ = -1; + std::string proxy_host_; + int proxy_port_ = -1; - std::string proxy_basic_auth_username_; - std::string proxy_basic_auth_password_; - std::string proxy_bearer_token_auth_token_; + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string proxy_digest_auth_username_; - std::string proxy_digest_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; #endif - Logger logger_; - - void copy_settings(const ClientImpl &rhs) { - client_cert_path_ = rhs.client_cert_path_; - client_key_path_ = rhs.client_key_path_; - connection_timeout_sec_ = rhs.connection_timeout_sec_; - read_timeout_sec_ = rhs.read_timeout_sec_; - read_timeout_usec_ = rhs.read_timeout_usec_; - write_timeout_sec_ = rhs.write_timeout_sec_; - write_timeout_usec_ = rhs.write_timeout_usec_; - basic_auth_username_ = rhs.basic_auth_username_; - basic_auth_password_ = rhs.basic_auth_password_; - bearer_token_auth_token_ = rhs.bearer_token_auth_token_; + Logger logger_; + + void copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - digest_auth_username_ = rhs.digest_auth_username_; - digest_auth_password_ = rhs.digest_auth_password_; + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; #endif - keep_alive_ = rhs.keep_alive_; - follow_location_ = rhs.follow_location_; - tcp_nodelay_ = rhs.tcp_nodelay_; - socket_options_ = rhs.socket_options_; - compress_ = rhs.compress_; - decompress_ = rhs.decompress_; - interface_ = rhs.interface_; - proxy_host_ = rhs.proxy_host_; - proxy_port_ = rhs.proxy_port_; - proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; - proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; - proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + tcp_nodelay_ = rhs.tcp_nodelay_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; - proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; #endif - logger_ = rhs.logger_; - } + logger_ = rhs.logger_; + } private: - socket_t create_client_socket() const; - bool read_response_line(Stream &strm, Response &res); - bool write_request(Stream &strm, const Request &req, bool close_connection); - bool redirect(const Request &req, Response &res); - bool handle_request(Stream &strm, const Request &req, Response &res, - bool close_connection); - void stop_core(); - std::shared_ptr send_with_content_provider( - const char *method, const char *path, const Headers &headers, - const std::string &body, size_t content_length, - ContentProvider content_provider, const char *content_type); - - virtual bool process_socket(Socket &socket, - std::function callback); - virtual bool is_ssl() const; + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool close_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool close_connection); + void stop_core(); + std::shared_ptr send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_socket(Socket &socket, + std::function callback); + virtual bool is_ssl() const; }; class Client { public: - // Universal interface - explicit Client(const char *scheme_host_port); - - explicit Client(const char *scheme_host_port, - const std::string &client_cert_path, - const std::string &client_key_path); - - // HTTP only interface - explicit Client(const std::string &host, int port); - - explicit Client(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); - - ~Client(); - - bool is_valid() const; - - Result Get(const char *path); - Result Get(const char *path, const Headers &headers); - Result Get(const char *path, Progress progress); - Result Get(const char *path, const Headers &headers, Progress progress); - Result Get(const char *path, ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); - Result Get(const char *path, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); - Result Get(const char *path, const Headers &headers, - ResponseHandler response_handler, ContentReceiver content_receiver, - Progress progress); - Result Get(const char *path, ResponseHandler response_handler, - ContentReceiver content_receiver, Progress progress); - - Result Head(const char *path); - Result Head(const char *path, const Headers &headers); - - Result Post(const char *path); - Result Post(const char *path, const std::string &body, - const char *content_type); - Result Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Post(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Post(const char *path, const Params ¶ms); - Result Post(const char *path, const Headers &headers, const Params ¶ms); - Result Post(const char *path, const MultipartFormDataItems &items); - Result Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - Result Put(const char *path); - Result Put(const char *path, const std::string &body, - const char *content_type); - Result Put(const char *path, const Headers &headers, const std::string &body, - const char *content_type); - Result Put(const char *path, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Headers &headers, size_t content_length, - ContentProvider content_provider, const char *content_type); - Result Put(const char *path, const Params ¶ms); - Result Put(const char *path, const Headers &headers, const Params ¶ms); - Result Patch(const char *path, const std::string &body, + // Universal interface + explicit Client(const char *scheme_host_port); + + explicit Client(const char *scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + ~Client(); + + bool is_valid() const; + + Result Get(const char *path); + Result Get(const char *path, const Headers &headers); + Result Get(const char *path, Progress progress); + Result Get(const char *path, const Headers &headers, Progress progress); + Result Get(const char *path, ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const char *path, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const char *path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const char *path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Head(const char *path); + Result Head(const char *path, const Headers &headers); + + Result Post(const char *path); + Result Post(const char *path, const std::string &body, + const char *content_type); + Result Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Post(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Post(const char *path, const Params ¶ms); + Result Post(const char *path, const Headers &headers, const Params ¶ms); + Result Post(const char *path, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const char *path); + Result Put(const char *path, const std::string &body, const char *content_type); - Result Patch(const char *path, const Headers &headers, - const std::string &body, const char *content_type); - Result Patch(const char *path, size_t content_length, + Result Put(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + Result Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type); - Result Patch(const char *path, const Headers &headers, size_t content_length, + Result Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type); + Result Put(const char *path, const Params ¶ms); + Result Put(const char *path, const Headers &headers, const Params ¶ms); + Result Patch(const char *path, const std::string &body, + const char *content_type); + Result Patch(const char *path, const Headers &headers, + const std::string &body, const char *content_type); + Result Patch(const char *path, size_t content_length, + ContentProvider content_provider, const char *content_type); + Result Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type); - Result Delete(const char *path); - Result Delete(const char *path, const std::string &body, - const char *content_type); - Result Delete(const char *path, const Headers &headers); - Result Delete(const char *path, const Headers &headers, - const std::string &body, const char *content_type); + Result Delete(const char *path); + Result Delete(const char *path, const std::string &body, + const char *content_type); + Result Delete(const char *path, const Headers &headers); + Result Delete(const char *path, const Headers &headers, + const std::string &body, const char *content_type); - Result Options(const char *path); - Result Options(const char *path, const Headers &headers); + Result Options(const char *path); + Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res); + bool send(const Request &req, Response &res); - size_t is_socket_open() const; + size_t is_socket_open() const; - void stop(); + void stop(); - void set_default_headers(Headers headers); + void set_default_headers(Headers headers); - void set_tcp_nodelay(bool on); - void set_socket_options(SocketOptions socket_options); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); - void set_connection_timeout(time_t sec, time_t usec = 0); - void set_read_timeout(time_t sec, time_t usec = 0); - void set_write_timeout(time_t sec, time_t usec = 0); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); - void set_basic_auth(const char *username, const char *password); - void set_bearer_token_auth(const char *token); + void set_basic_auth(const char *username, const char *password); + void set_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const char *username, const char *password); + void set_digest_auth(const char *username, const char *password); #endif - void set_keep_alive(bool on); - void set_follow_location(bool on); + void set_keep_alive(bool on); + void set_follow_location(bool on); - void set_compress(bool on); + void set_compress(bool on); - void set_decompress(bool on); + void set_decompress(bool on); - void set_interface(const char *intf); + void set_interface(const char *intf); - void set_proxy(const char *host, int port); - void set_proxy_basic_auth(const char *username, const char *password); - void set_proxy_bearer_token_auth(const char *token); + void set_proxy(const char *host, int port); + void set_proxy_basic_auth(const char *username, const char *password); + void set_proxy_bearer_token_auth(const char *token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const char *username, const char *password); + void set_proxy_digest_auth(const char *username, const char *password); #endif - void set_logger(Logger logger); + void set_logger(Logger logger); - // SSL + // SSL #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - Client &set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path = nullptr); + Client &set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path = nullptr); - Client &set_ca_cert_store(X509_STORE *ca_cert_store); + Client &set_ca_cert_store(X509_STORE *ca_cert_store); - Client &enable_server_certificate_verification(bool enabled); + Client &enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX *ssl_context() const; + SSL_CTX *ssl_context() const; #endif private: - std::shared_ptr cli_; + std::shared_ptr cli_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool is_ssl_ = false; + bool is_ssl_ = false; #endif -}; // namespace httplib +}; // namespace httplib #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLServer : public Server { public: - SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path = nullptr, - const char *client_ca_cert_dir_path = nullptr); + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); - ~SSLServer() override; + ~SSLServer() override; - bool is_valid() const override; + bool is_valid() const override; private: - bool process_and_close_socket(socket_t sock) override; + bool process_and_close_socket(socket_t sock) override; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; }; class SSLClient : public ClientImpl { public: - explicit SSLClient(const std::string &host); + explicit SSLClient(const std::string &host); - explicit SSLClient(const std::string &host, int port); + explicit SSLClient(const std::string &host, int port); - explicit SSLClient(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path); + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key); + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key); - ~SSLClient() override; + ~SSLClient() override; - bool is_valid() const override; + bool is_valid() const override; - void set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path = nullptr); + void set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path = nullptr); - void set_ca_cert_store(X509_STORE *ca_cert_store); + void set_ca_cert_store(X509_STORE *ca_cert_store); - void enable_server_certificate_verification(bool enabled); + void enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX *ssl_context() const; + SSL_CTX *ssl_context() const; private: - bool create_and_connect_socket(Socket &socket) override; - void close_socket(Socket &socket, bool process_socket_ret) override; + bool create_and_connect_socket(Socket &socket) override; + void close_socket(Socket &socket, bool process_socket_ret) override; - bool process_socket(Socket &socket, - std::function callback) override; - bool is_ssl() const override; + bool process_socket(Socket &socket, + std::function callback) override; + bool is_ssl() const override; - bool connect_with_proxy(Socket &sock, Response &res, bool &success); - bool initialize_ssl(Socket &socket); + bool connect_with_proxy(Socket &sock, Response &res, bool &success); + bool initialize_ssl(Socket &socket); - bool load_certs(); + bool load_certs(); - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; - std::once_flag initialize_cert_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; - std::vector host_components_; + std::vector host_components_; - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; - X509_STORE *ca_cert_store_ = nullptr; - bool server_certificate_verification_ = true; - long verify_result_ = 0; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + X509_STORE *ca_cert_store_ = nullptr; + bool server_certificate_verification_ = true; + long verify_result_ = 0; - friend class ClientImpl; + friend class ClientImpl; }; #endif @@ -1212,724 +1239,779 @@ class SSLClient : public ClientImpl { namespace detail { inline bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } - return false; + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; } inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) { - if (i >= s.size()) { return false; } - - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { return false; } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { - return false; + if (i >= s.size()) { + return false; + } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { + return false; + } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } } - } - return true; + return true; } inline std::string from_i_to_hex(size_t n) { - const char *charset = "0123456789abcdef"; - std::string ret; - do { - ret = charset[n & 15] + ret; - n >>= 4; - } while (n > 0); - return ret; + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; } inline bool start_with(const std::string &a, const std::string &b) { - if (a.size() < b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { - if (std::tolower(a[i]) != std::tolower(b[i])) { return false; } - } - return true; + if (a.size() < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (std::tolower(a[i]) != std::tolower(b[i])) { + return false; + } + } + return true; } inline size_t to_utf8(int code, char *buff) { - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); - buff[1] = static_cast(0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); - buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[2] = static_cast(0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED return 0; - } else if (code < 0x10000) { - buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); - buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[2] = static_cast(0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); - buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); - buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); - buff[3] = static_cast(0x80 | (code & 0x3F)); - return 4; - } - - // NOTREACHED - return 0; } // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c inline std::string base64_encode(const std::string &in) { - static const auto lookup = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(in.size()); + std::string out; + out.reserve(in.size()); - int val = 0; - int valb = -6; + int val = 0; + int valb = -6; - for (auto c : in) { - val = (val << 8) + static_cast(c); - valb += 8; - while (valb >= 0) { - out.push_back(lookup[(val >> valb) & 0x3F]); - valb -= 6; + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } } - } - if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + if (valb > -6) { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } - while (out.size() % 4) { - out.push_back('='); - } + while (out.size() % 4) { + out.push_back('='); + } - return out; + return out; } inline bool is_file(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); } inline bool is_dir(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); } inline bool is_valid_path(const std::string &path) { - size_t level = 0; - size_t i = 0; - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } + size_t level = 0; + size_t i = 0; - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; } - auto len = i - beg; - assert(len > 0); + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { return false; } - level--; - } else { - level++; - } + auto len = i - beg; + assert(len > 0); - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { + return false; + } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } } - } - return true; + return true; } inline std::string encode_url(const std::string &s) { - std::string result; - - for (size_t i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "%20"; break; - case '+': result += "%2B"; break; - case '\r': result += "%0D"; break; - case '\n': result += "%0A"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - // case ':': result += "%3A"; break; // ok? probably... - case ';': result += "%3B"; break; - default: - auto c = static_cast(s[i]); - if (c >= 0x80) { - result += '%'; - char hex[4]; - auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); - assert(len == 2); - result.append(hex, static_cast(len)); - } else { - result += s[i]; - } - break; + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': + result += "%20"; + break; + case '+': + result += "%2B"; + break; + case '\r': + result += "%0D"; + break; + case '\n': + result += "%0A"; + break; + case '\'': + result += "%27"; + break; + case ',': + result += "%2C"; + break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': + result += "%3B"; + break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } } - } - return result; + return result; } inline std::string decode_url(const std::string &s, bool convert_plus_to_space) { - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { result.append(buff, len); } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += static_cast(val); - i += 2; // '00' + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { + result.append(buff, len); + } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; } else { - result += s[i]; + result += s[i]; } - } - } else if (convert_plus_to_space && s[i] == '+') { - result += ' '; - } else { - result += s[i]; } - } - return result; + return result; } inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); } inline std::string file_extension(const std::string &path) { - std::smatch m; - static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, re)) { return m[1].str(); } - return std::string(); + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { + return m[1].str(); + } + return std::string(); } -inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +inline bool is_space_or_tab(char c) { + return c == ' ' || c == '\t'; +} inline std::pair trim(const char *b, const char *e, size_t left, size_t right) { - while (b + left < e && is_space_or_tab(b[left])) { - left++; - } - while (right > 0 && is_space_or_tab(b[right - 1])) { - right--; - } - return std::make_pair(left, right); + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); } inline std::string trim_copy(const std::string &s) { - auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); - return s.substr(r.first, r.second - r.first); + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); } -template void split(const char *b, const char *e, char d, Fn fn) { - size_t i = 0; - size_t beg = 0; +template +void split(const char *b, const char *e, char d, Fn fn) { + size_t i = 0; + size_t beg = 0; - while (e ? (b + i < e) : (b[i] != '\0')) { - if (b[i] == d) { - auto r = trim(b, e, beg, i); - if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } - beg = i + 1; + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + beg = i + 1; + } + i++; } - i++; - } - if (i) { - auto r = trim(b, e, beg, i); - if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } - } + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + } } // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { public: - stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) - : strm_(strm), fixed_buffer_(fixed_buffer), - fixed_buffer_size_(fixed_buffer_size) {} + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) { + } - const char *ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } } - } - size_t size() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_used_size_; - } else { - return glowable_buffer_.size(); + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } } - } - bool end_with_crlf() const { - auto end = ptr() + size(); - return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; - } + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } - for (size_t i = 0;; i++) { - char byte; - auto n = strm_.read(&byte, 1); + append(byte); - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; + if (byte == '\n') { + break; + } } - } - append(byte); - - if (byte == '\n') { break; } + return true; } - return true; - } - private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); - } - glowable_buffer_ += c; - } - } - - Stream &strm_; - char *fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_ = 0; - std::string glowable_buffer_; + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; }; inline int close_socket(socket_t sock) { #ifdef _WIN32 - return closesocket(sock); + return closesocket(sock); #else - return close(sock); + return close(sock); #endif } -template inline ssize_t handle_EINTR(T fn) { - ssize_t res = false; - while (true) { - res = fn(); - if (res < 0 && errno == EINTR) { continue; } - break; - } - return res; +template +inline ssize_t handle_EINTR(T fn) { + ssize_t res = false; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + continue; + } + break; + } + return res; } inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); #else - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); - return handle_EINTR([&]() { - return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); - }); + return handle_EINTR([&]() { + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); + }); #endif } inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLOUT; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); #else - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); - return handle_EINTR([&]() { - return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); - }); + return handle_EINTR([&]() { + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); + }); #endif } inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN | POLLOUT; - - auto timeout = static_cast(sec * 1000 + usec / 1000); - - auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); - - if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { - int error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - return res >= 0 && !error; - } - return false; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + return res >= 0 && !error; + } + return false; #else - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); - }); - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - int error = 0; - socklen_t len = sizeof(error); - return getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len) >= 0 && - !error; - } - return false; + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); + }); + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; #endif } class SocketStream : public Stream { public: - SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec); - ~SocketStream() override; + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec); + ~SocketStream() override; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: - socket_t sock_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { public: - SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec); - ~SSLSocketStream() override; + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: - socket_t sock_; - SSL *ssl_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #endif class BufferStream : public Stream { public: - BufferStream() = default; - ~BufferStream() override = default; + BufferStream() = default; + ~BufferStream() override = default; - bool is_readable() const override; - bool is_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; - const std::string &get_buffer() const; + const std::string &get_buffer() const; private: - std::string buffer; - size_t position = 0; + std::string buffer; + size_t position = 0; }; inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { - using namespace std::chrono; - auto start = steady_clock::now(); - while (true) { - auto val = select_read(sock, 0, 10000); - if (val < 0) { - return false; - } else if (val == 0) { - auto current = steady_clock::now(); - auto duration = duration_cast(current - start); - auto timeout = keep_alive_timeout_sec * 1000; - if (duration.count() > timeout) { return false; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } else { - return true; + using namespace std::chrono; + auto start = steady_clock::now(); + while (true) { + auto val = select_read(sock, 0, 10000); + if (val < 0) { + return false; + } else if (val == 0) { + auto current = steady_clock::now(); + auto duration = duration_cast(current - start); + auto timeout = keep_alive_timeout_sec * 1000; + if (duration.count() > timeout) { + return false; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + return true; + } } - } } -template +template inline bool process_server_socket_core(socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, T callback) { - assert(keep_alive_max_count > 0); - auto ret = false; - auto count = keep_alive_max_count; - while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { - auto close_connection = count == 1; - auto connection_closed = false; - ret = callback(close_connection, connection_closed); - if (!ret || connection_closed) { break; } - count--; - } - return ret; -} - -template + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { + break; + } + count--; + } + return ret; +} + +template inline bool process_server_socket(socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); + return process_server_socket_core( + sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); } -template +template inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm); + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); } inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif } -template +template socket_t create_socket(const char *host, int port, int socket_flags, bool tcp_nodelay, SocketOptions socket_options, BindOrConnect bind_or_connect) { - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { + if (getaddrinfo(host, service.c_str(), &hints, &result)) { #ifdef __linux__ - res_init(); + res_init(); #endif - return INVALID_SOCKET; - } + return INVALID_SOCKET; + } - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket #ifdef _WIN32 - auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, - nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); - /** - * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 - * and above the socket creation fails on older Windows Systems. - * - * Let's try to create a socket the old way in this case. - * - * Reference: - * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa - * - * WSA_FLAG_NO_HANDLE_INHERIT: - * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with - * SP1, and later - * - */ - if (sock == INVALID_SOCKET) { - sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - } + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } #else - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); #endif - if (sock == INVALID_SOCKET) { continue; } + if (sock == INVALID_SOCKET) { + continue; + } #ifdef __linux__ - if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + continue; + } #endif - if (tcp_nodelay) { - int yes = 1; - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), - sizeof(yes)); - } + if (tcp_nodelay) { + int yes = 1; + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), + sizeof(yes)); + } - if (socket_options) { socket_options(sock); } + if (socket_options) { + socket_options(sock); + } - if (rp->ai_family == AF_INET6) { - int no = 0; - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), - sizeof(no)); - } + if (rp->ai_family == AF_INET6) { + int no = 0; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), + sizeof(no)); + } - // bind or connect - if (bind_or_connect(sock, *rp)) { - freeaddrinfo(result); - return sock; - } + // bind or connect + if (bind_or_connect(sock, *rp)) { + freeaddrinfo(result); + return sock; + } - close_socket(sock); - } + close_socket(sock); + } - freeaddrinfo(result); - return INVALID_SOCKET; + freeaddrinfo(result); + return INVALID_SOCKET; } inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; - ioctlsocket(sock, FIONBIO, &flags); + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, - nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif } inline bool is_connection_error() { #ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; + return WSAGetLastError() != WSAEWOULDBLOCK; #else - return errno != EINPROGRESS; + return errno != EINPROGRESS; #endif } inline bool bind_ip_address(socket_t sock, const char *host) { - struct addrinfo hints; - struct addrinfo *result; + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; - if (getaddrinfo(host, "0", &hints, &result)) { return false; } + if (getaddrinfo(host, "0", &hints, &result)) { + return false; + } - auto ret = false; - for (auto rp = result; rp; rp = rp->ai_next) { - const auto &ai = *rp; - if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { - ret = true; - break; + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } } - } - freeaddrinfo(result); - return ret; + freeaddrinfo(result); + return ret; } #if !defined _WIN32 && !defined ANDROID @@ -1938,22 +2020,22 @@ inline bool bind_ip_address(socket_t sock, const char *host) { #ifdef USE_IF2IP inline std::string if2ip(const std::string &ifn) { - struct ifaddrs *ifap; - getifaddrs(&ifap); - for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { - if (ifa->ifa_addr && ifn == ifa->ifa_name) { - if (ifa->ifa_addr->sa_family == AF_INET) { - auto sa = reinterpret_cast(ifa->ifa_addr); - char buf[INET_ADDRSTRLEN]; - if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { - freeifaddrs(ifap); - return std::string(buf, INET_ADDRSTRLEN); - } - } - } - } - freeifaddrs(ifap); - return std::string(); + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); + return std::string(); } #endif @@ -1962,1346 +2044,1544 @@ inline socket_t create_client_socket(const char *host, int port, SocketOptions socket_options, time_t timeout_sec, time_t timeout_usec, const std::string &intf, Error &error) { - auto sock = create_socket( - host, port, 0, tcp_nodelay, socket_options, - [&](socket_t sock, struct addrinfo &ai) -> bool { - if (!intf.empty()) { + auto sock = create_socket( + host, port, 0, tcp_nodelay, socket_options, + [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { #ifdef USE_IF2IP - auto ip = if2ip(intf); - if (ip.empty()) { ip = intf; } - if (!bind_ip_address(sock, ip.c_str())) { - error = Error::BindIPAddress; - return false; - } + auto ip = if2ip(intf); + if (ip.empty()) { + ip = intf; + } + if (!bind_ip_address(sock, ip.c_str())) { + error = Error::BindIPAddress; + return false; + } #endif - } + } - set_nonblocking(sock, true); + set_nonblocking(sock, true); - auto ret = - ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + auto ret = + ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); - if (ret < 0) { - if (is_connection_error() || - !wait_until_socket_is_ready(sock, timeout_sec, timeout_usec)) { - close_socket(sock); - error = Error::Connection; - return false; - } - } + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, timeout_usec)) { + close_socket(sock); + error = Error::Connection; + return false; + } + } - set_nonblocking(sock, false); - error = Error::Success; - return true; - }); + set_nonblocking(sock, false); + error = Error::Success; + return true; + }); - if (sock != INVALID_SOCKET) { - error = Error::Success; - } else { - if (error == Error::Success) { error = Error::Connection; } - } + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { + error = Error::Connection; + } + } - return sock; + return sock; } inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { - if (addr.ss_family == AF_INET) { - port = ntohs(reinterpret_cast(&addr)->sin_port); - } else if (addr.ss_family == AF_INET6) { - port = - ntohs(reinterpret_cast(&addr)->sin6_port); - } + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } - std::array ipstr{}; - if (!getnameinfo(reinterpret_cast(&addr), addr_len, - ipstr.data(), static_cast(ipstr.size()), nullptr, - 0, NI_NUMERICHOST)) { - ip = ipstr.data(); - } + std::array ipstr{}; + if (!getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + ip = ipstr.data(); + } } inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); - if (!getpeername(sock, reinterpret_cast(&addr), - &addr_len)) { - get_remote_ip_and_port(addr, addr_len, ip, port); - } + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { + get_remote_ip_and_port(addr, addr_len, ip, port); + } } inline const char * find_content_type(const std::string &path, const std::map &user_data) { - auto ext = file_extension(path); - - auto it = user_data.find(ext); - if (it != user_data.end()) { return it->second.c_str(); } - - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html" || ext == "htm") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "wasm") { - return "application/wasm"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { + return it->second.c_str(); + } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; } inline const char *status_message(int status) { - switch (status) { - case 100: return "Continue"; - case 101: return "Switching Protocol"; - case 102: return "Processing"; - case 103: return "Early Hints"; - case 200: return "OK"; - case 201: return "Created"; - case 202: return "Accepted"; - case 203: return "Non-Authoritative Information"; - case 204: return "No Content"; - case 205: return "Reset Content"; - case 206: return "Partial Content"; - case 207: return "Multi-Status"; - case 208: return "Already Reported"; - case 226: return "IM Used"; - case 300: return "Multiple Choice"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 305: return "Use Proxy"; - case 306: return "unused"; - case 307: return "Temporary Redirect"; - case 308: return "Permanent Redirect"; - case 400: return "Bad Request"; - case 401: return "Unauthorized"; - case 402: return "Payment Required"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 405: return "Method Not Allowed"; - case 406: return "Not Acceptable"; - case 407: return "Proxy Authentication Required"; - case 408: return "Request Timeout"; - case 409: return "Conflict"; - case 410: return "Gone"; - case 411: return "Length Required"; - case 412: return "Precondition Failed"; - case 413: return "Payload Too Large"; - case 414: return "URI Too Long"; - case 415: return "Unsupported Media Type"; - case 416: return "Range Not Satisfiable"; - case 417: return "Expectation Failed"; - case 418: return "I'm a teapot"; - case 421: return "Misdirected Request"; - case 422: return "Unprocessable Entity"; - case 423: return "Locked"; - case 424: return "Failed Dependency"; - case 425: return "Too Early"; - case 426: return "Upgrade Required"; - case 428: return "Precondition Required"; - case 429: return "Too Many Requests"; - case 431: return "Request Header Fields Too Large"; - case 451: return "Unavailable For Legal Reasons"; - case 501: return "Not Implemented"; - case 502: return "Bad Gateway"; - case 503: return "Service Unavailable"; - case 504: return "Gateway Timeout"; - case 505: return "HTTP Version Not Supported"; - case 506: return "Variant Also Negotiates"; - case 507: return "Insufficient Storage"; - case 508: return "Loop Detected"; - case 510: return "Not Extended"; - case 511: return "Network Authentication Required"; - - default: - case 500: return "Internal Server Error"; - } + switch (status) { + case 100: + return "Continue"; + case 101: + return "Switching Protocol"; + case 102: + return "Processing"; + case 103: + return "Early Hints"; + case 200: + return "OK"; + case 201: + return "Created"; + case 202: + return "Accepted"; + case 203: + return "Non-Authoritative Information"; + case 204: + return "No Content"; + case 205: + return "Reset Content"; + case 206: + return "Partial Content"; + case 207: + return "Multi-Status"; + case 208: + return "Already Reported"; + case 226: + return "IM Used"; + case 300: + return "Multiple Choice"; + case 301: + return "Moved Permanently"; + case 302: + return "Found"; + case 303: + return "See Other"; + case 304: + return "Not Modified"; + case 305: + return "Use Proxy"; + case 306: + return "unused"; + case 307: + return "Temporary Redirect"; + case 308: + return "Permanent Redirect"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 402: + return "Payment Required"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 406: + return "Not Acceptable"; + case 407: + return "Proxy Authentication Required"; + case 408: + return "Request Timeout"; + case 409: + return "Conflict"; + case 410: + return "Gone"; + case 411: + return "Length Required"; + case 412: + return "Precondition Failed"; + case 413: + return "Payload Too Large"; + case 414: + return "URI Too Long"; + case 415: + return "Unsupported Media Type"; + case 416: + return "Range Not Satisfiable"; + case 417: + return "Expectation Failed"; + case 418: + return "I'm a teapot"; + case 421: + return "Misdirected Request"; + case 422: + return "Unprocessable Entity"; + case 423: + return "Locked"; + case 424: + return "Failed Dependency"; + case 425: + return "Too Early"; + case 426: + return "Upgrade Required"; + case 428: + return "Precondition Required"; + case 429: + return "Too Many Requests"; + case 431: + return "Request Header Fields Too Large"; + case 451: + return "Unavailable For Legal Reasons"; + case 501: + return "Not Implemented"; + case 502: + return "Bad Gateway"; + case 503: + return "Service Unavailable"; + case 504: + return "Gateway Timeout"; + case 505: + return "HTTP Version Not Supported"; + case 506: + return "Variant Also Negotiates"; + case 507: + return "Insufficient Storage"; + case 508: + return "Loop Detected"; + case 510: + return "Not Extended"; + case 511: + return "Network Authentication Required"; + + default: + case 500: + return "Internal Server Error"; + } } inline bool can_compress_content_type(const std::string &content_type) { - return (!content_type.find("text/") && content_type != "text/event-stream") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; + return (!content_type.find("text/") && content_type != "text/event-stream") || + content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; } -enum class EncodingType { None = 0, Gzip, Brotli }; +enum class EncodingType { None = 0, + Gzip, + Brotli }; inline EncodingType encoding_type(const Request &req, const Response &res) { - auto ret = - detail::can_compress_content_type(res.get_header_value("Content-Type")); - if (!ret) { return EncodingType::None; } + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { + return EncodingType::None; + } - const auto &s = req.get_header_value("Accept-Encoding"); - (void)(s); + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); #ifdef CPPHTTPLIB_BROTLI_SUPPORT - // TODO: 'Accept-Encoding' has br, not br;q=0 - ret = s.find("br") != std::string::npos; - if (ret) { return EncodingType::Brotli; } + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { + return EncodingType::Brotli; + } #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 - ret = s.find("gzip") != std::string::npos; - if (ret) { return EncodingType::Gzip; } + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { + return EncodingType::Gzip; + } #endif - return EncodingType::None; + return EncodingType::None; } class compressor { public: - virtual ~compressor(){}; + virtual ~compressor(){}; - typedef std::function Callback; - virtual bool compress(const char *data, size_t data_length, bool last, - Callback callback) = 0; + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; }; class decompressor { public: - virtual ~decompressor() {} + virtual ~decompressor() { + } - virtual bool is_valid() const = 0; + virtual bool is_valid() const = 0; - typedef std::function Callback; - virtual bool decompress(const char *data, size_t data_length, - Callback callback) = 0; + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; }; class nocompressor : public compressor { public: - ~nocompressor(){}; + ~nocompressor(){}; - bool compress(const char *data, size_t data_length, bool /*last*/, - Callback callback) override { - if (!data_length) { return true; } - return callback(data, data_length); - } + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override { + if (!data_length) { + return true; + } + return callback(data, data_length); + } }; #ifdef CPPHTTPLIB_ZLIB_SUPPORT class gzip_compressor : public compressor { public: - gzip_compressor() { - std::memset(&strm_, 0, sizeof(strm_)); - strm_.zalloc = Z_NULL; - strm_.zfree = Z_NULL; - strm_.opaque = Z_NULL; - - is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, - Z_DEFAULT_STRATEGY) == Z_OK; - } + gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; + } - ~gzip_compressor() { deflateEnd(&strm_); } + ~gzip_compressor() { + deflateEnd(&strm_); + } - bool compress(const char *data, size_t data_length, bool last, - Callback callback) override { - assert(is_valid_); + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override { + assert(is_valid_); - auto flush = last ? Z_FINISH : Z_NO_FLUSH; + auto flush = last ? Z_FINISH : Z_NO_FLUSH; - strm_.avail_in = static_cast(data_length); - strm_.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - int ret = Z_OK; + int ret = Z_OK; - std::array buff{}; - do { - strm_.avail_out = buff.size(); - strm_.next_out = reinterpret_cast(buff.data()); + std::array buff{}; + do { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = deflate(&strm_, flush); - assert(ret != Z_STREAM_ERROR); + ret = deflate(&strm_, flush); + assert(ret != Z_STREAM_ERROR); - if (!callback(buff.data(), buff.size() - strm_.avail_out)) { - return false; - } - } while (strm_.avail_out == 0); + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); - assert((last && ret == Z_STREAM_END) || (!last && ret == Z_OK)); - assert(strm_.avail_in == 0); - return true; - } + assert((last && ret == Z_STREAM_END) || (!last && ret == Z_OK)); + assert(strm_.avail_in == 0); + return true; + } private: - bool is_valid_ = false; - z_stream strm_; + bool is_valid_ = false; + z_stream strm_; }; class gzip_decompressor : public decompressor { public: - gzip_decompressor() { - std::memset(&strm_, 0, sizeof(strm_)); - strm_.zalloc = Z_NULL; - strm_.zfree = Z_NULL; - strm_.opaque = Z_NULL; - - // 15 is the value of wbits, which should be at the maximum possible value - // to ensure that any gzip stream can be decoded. The offset of 32 specifies - // that the stream type should be automatically detected either gzip or - // deflate. - is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; - } + gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; + } - ~gzip_decompressor() { inflateEnd(&strm_); } + ~gzip_decompressor() { + inflateEnd(&strm_); + } - bool is_valid() const override { return is_valid_; } + bool is_valid() const override { + return is_valid_; + } - bool decompress(const char *data, size_t data_length, - Callback callback) override { - assert(is_valid_); + bool decompress(const char *data, size_t data_length, + Callback callback) override { + assert(is_valid_); - int ret = Z_OK; + int ret = Z_OK; - strm_.avail_in = static_cast(data_length); - strm_.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - std::array buff{}; - while (strm_.avail_in > 0) { - strm_.avail_out = buff.size(); - strm_.next_out = reinterpret_cast(buff.data()); + std::array buff{}; + while (strm_.avail_in > 0) { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = inflate(&strm_, Z_NO_FLUSH); - assert(ret != Z_STREAM_ERROR); - switch (ret) { - case Z_NEED_DICT: - case Z_DATA_ERROR: - case Z_MEM_ERROR: inflateEnd(&strm_); return false; - } + ret = inflate(&strm_, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm_); + return false; + } - if (!callback(buff.data(), buff.size() - strm_.avail_out)) { - return false; - } - } + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } - return ret == Z_OK || ret == Z_STREAM_END; - } + return ret == Z_OK || ret == Z_STREAM_END; + } private: - bool is_valid_ = false; - z_stream strm_; + bool is_valid_ = false; + z_stream strm_; }; #endif #ifdef CPPHTTPLIB_BROTLI_SUPPORT class brotli_compressor : public compressor { public: - brotli_compressor() { - state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); - } + brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + } - ~brotli_compressor() { BrotliEncoderDestroyInstance(state_); } + ~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); + } - bool compress(const char *data, size_t data_length, bool last, - Callback callback) override { - std::array buff{}; + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { + break; + } + } else { + if (!available_in) { + break; + } + } - auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; - auto available_in = data_length; - auto next_in = reinterpret_cast(data); + auto available_out = buff.size(); + auto next_out = buff.data(); - for (;;) { - if (last) { - if (BrotliEncoderIsFinished(state_)) { break; } - } else { - if (!available_in) { break; } - } - - auto available_out = buff.size(); - auto next_out = buff.data(); - - if (!BrotliEncoderCompressStream(state_, operation, &available_in, - &next_in, &available_out, &next_out, - nullptr)) { - return false; - } + if (!BrotliEncoderCompressStream(state_, operation, &available_in, + &next_in, &available_out, &next_out, + nullptr)) { + return false; + } - auto output_bytes = buff.size() - available_out; - if (output_bytes) { - callback(reinterpret_cast(buff.data()), output_bytes); - } - } + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } - return true; - } + return true; + } private: - BrotliEncoderState *state_ = nullptr; + BrotliEncoderState *state_ = nullptr; }; class brotli_decompressor : public decompressor { public: - brotli_decompressor() { - decoder_s = BrotliDecoderCreateInstance(0, 0, 0); - decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT - : BROTLI_DECODER_RESULT_ERROR; - } - - ~brotli_decompressor() { - if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } - } + brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT : BROTLI_DECODER_RESULT_ERROR; + } - bool is_valid() const override { return decoder_s; } + ~brotli_decompressor() { + if (decoder_s) { + BrotliDecoderDestroyInstance(decoder_s); + } + } - bool decompress(const char *data, size_t data_length, - Callback callback) override { - if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || - decoder_r == BROTLI_DECODER_RESULT_ERROR) { - return 0; + bool is_valid() const override { + return decoder_s; } - const uint8_t *next_in = (const uint8_t *)data; - size_t avail_in = data_length; - size_t total_out; + bool decompress(const char *data, size_t data_length, + Callback callback) override { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } - decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + const uint8_t *next_in = (const uint8_t *)data; + size_t avail_in = data_length; + size_t total_out; - std::array buff{}; - while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { - char *next_out = buff.data(); - size_t avail_out = buff.size(); + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; - decoder_r = BrotliDecoderDecompressStream( - decoder_s, &avail_in, &next_in, &avail_out, - reinterpret_cast(&next_out), &total_out); + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); - if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); - if (!callback(buff.data(), buff.size() - avail_out)) { return false; } - } + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return false; + } - return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || - decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; - } + if (!callback(buff.data(), buff.size() - avail_out)) { + return false; + } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; + } private: - BrotliDecoderResult decoder_r; - BrotliDecoderState *decoder_s = nullptr; + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; }; #endif inline bool has_header(const Headers &headers, const char *key) { - return headers.find(key) != headers.end(); + return headers.find(key) != headers.end(); } inline const char *get_header_value(const Headers &headers, const char *key, size_t id = 0, const char *def = nullptr) { - auto rng = headers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second.c_str(); } - return def; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second.c_str(); + } + return def; } -template +template inline T get_header_value(const Headers & /*headers*/, const char * /*key*/, - size_t /*id*/ = 0, uint64_t /*def*/ = 0) {} + size_t /*id*/ = 0, uint64_t /*def*/ = 0) { +} -template <> +template<> inline uint64_t get_header_value(const Headers &headers, const char *key, size_t id, uint64_t def) { - auto rng = headers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { - return std::strtoull(it->second.data(), nullptr, 10); - } - return def; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; } -template +template inline bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; - } + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } - auto p = beg; - while (p < end && *p != ':') { - p++; - } + auto p = beg; + while (p < end && *p != ':') { + p++; + } - if (p == end) { return false; } + if (p == end) { + return false; + } - auto key_end = p; + auto key_end = p; - if (*p++ != ':') { return false; } + if (*p++ != ':') { + return false; + } - while (p < end && is_space_or_tab(*p)) { - p++; - } + while (p < end && is_space_or_tab(*p)) { + p++; + } - if (p < end) { - fn(std::string(beg, key_end), decode_url(std::string(p, end), false)); - return true; - } + if (p < end) { + fn(std::string(beg, key_end), decode_url(std::string(p, end), false)); + return true; + } - return false; + return false; } inline bool read_headers(Stream &strm, Headers &headers) { - const auto bufsiz = 2048; - char buf[bufsiz]; - stream_line_reader line_reader(strm, buf, bufsiz); + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); - for (;;) { - if (!line_reader.getline()) { return false; } + for (;;) { + if (!line_reader.getline()) { + return false; + } - // Check if the line ends with CRLF. - if (line_reader.end_with_crlf()) { - // Blank line indicates end of headers. - if (line_reader.size() == 2) { break; } - } else { - continue; // Skip invalid line. - } + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { + break; + } + } else { + continue; // Skip invalid line. + } - // Exclude CRLF - auto end = line_reader.ptr() + line_reader.size() - 2; + // Exclude CRLF + auto end = line_reader.ptr() + line_reader.size() - 2; - parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - headers.emplace(std::move(key), std::move(val)); - }); - } + parse_header(line_reader.ptr(), end, + [&](std::string &&key, std::string &&val) { + headers.emplace(std::move(key), std::move(val)); + }); + } - return true; + return true; } inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiver out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return false; } + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return false; + } - if (!out(buf, static_cast(n))) { return false; } + if (!out(buf, static_cast(n))) { + return false; + } - r += static_cast(n); + r += static_cast(n); - if (progress) { - if (!progress(r, len)) { return false; } + if (progress) { + if (!progress(r, len)) { + return false; + } + } } - } - return true; + return true; } inline void skip_content_with_length(Stream &strm, uint64_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += static_cast(n); - } + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return; + } + r += static_cast(n); + } } inline bool read_content_without_length(Stream &strm, ContentReceiver out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - for (;;) { - auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n < 0) { - return false; - } else if (n == 0) { - return true; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, static_cast(n))) { + return false; + } } - if (!out(buf, static_cast(n))) { return false; } - } - return true; + return true; } inline bool read_content_chunked(Stream &strm, ContentReceiver out) { - const auto bufsiz = 16; - char buf[bufsiz]; + const auto bufsiz = 16; + char buf[bufsiz]; - stream_line_reader line_reader(strm, buf, bufsiz); + stream_line_reader line_reader(strm, buf, bufsiz); - if (!line_reader.getline()) { return false; } + if (!line_reader.getline()) { + return false; + } - unsigned long chunk_len; - while (true) { - char *end_ptr; + unsigned long chunk_len; + while (true) { + char *end_ptr; - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); - if (end_ptr == line_reader.ptr()) { return false; } - if (chunk_len == ULONG_MAX) { return false; } + if (end_ptr == line_reader.ptr()) { + return false; + } + if (chunk_len == ULONG_MAX) { + return false; + } - if (chunk_len == 0) { break; } + if (chunk_len == 0) { + break; + } - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { - return false; - } + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } - if (!line_reader.getline()) { return false; } + if (!line_reader.getline()) { + return false; + } - if (strcmp(line_reader.ptr(), "\r\n")) { break; } + if (strcmp(line_reader.ptr(), "\r\n")) { + break; + } - if (!line_reader.getline()) { return false; } - } + if (!line_reader.getline()) { + return false; + } + } - if (chunk_len == 0) { - // Reader terminator after chunks - if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) - return false; - } + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } - return true; + return true; } inline bool is_chunked_transfer_encoding(const Headers &headers) { - return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), - "chunked"); + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); } -template +template bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver, bool decompress, U callback) { - if (decompress) { - std::string encoding = x.get_header_value("Content-Encoding"); - std::shared_ptr decompressor; + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::shared_ptr decompressor; - if (encoding.find("gzip") != std::string::npos || - encoding.find("deflate") != std::string::npos) { + if (encoding.find("gzip") != std::string::npos || + encoding.find("deflate") != std::string::npos) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = std::make_shared(); + decompressor = std::make_shared(); #else - status = 415; - return false; + status = 415; + return false; #endif - } else if (encoding.find("br") != std::string::npos) { + } else if (encoding.find("br") != std::string::npos) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = std::make_shared(); + decompressor = std::make_shared(); #else - status = 415; - return false; + status = 415; + return false; #endif - } + } - if (decompressor) { - if (decompressor->is_valid()) { - ContentReceiver out = [&](const char *buf, size_t n) { - return decompressor->decompress( - buf, n, - [&](const char *buf, size_t n) { return receiver(buf, n); }); - }; - return callback(out); - } else { - status = 500; - return false; - } + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiver out = [&](const char *buf, size_t n) { + return decompressor->decompress( + buf, n, + [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + return callback(out); + } else { + status = 500; + return false; + } + } } - } - ContentReceiver out = [&](const char *buf, size_t n) { - return receiver(buf, n); - }; - return callback(out); + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + return callback(out); } -template +template bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, Progress progress, ContentReceiver receiver, bool decompress) { - return prepare_content_receiver( - x, status, receiver, decompress, [&](const ContentReceiver &out) { - auto ret = true; - auto exceed_payload_max_length = false; - - if (is_chunked_transfer_encoding(x.headers)) { - ret = read_content_chunked(strm, out); - } else if (!has_header(x.headers, "Content-Length")) { - ret = read_content_without_length(strm, out); - } else { - auto len = get_header_value(x.headers, "Content-Length"); - if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; - } else if (len > 0) { - ret = read_content_with_length(strm, len, progress, out); - } - } + return prepare_content_receiver( + x, status, receiver, decompress, [&](const ContentReceiver &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value(x.headers, "Content-Length"); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } - if (!ret) { status = exceed_payload_max_length ? 413 : 400; } - return ret; - }); + if (!ret) { + status = exceed_payload_max_length ? 413 : 400; + } + return ret; + }); } -template +template inline ssize_t write_headers(Stream &strm, const T &info, const Headers &headers) { - ssize_t write_len = 0; - for (const auto &x : info.headers) { - if (x.first == "EXCEPTION_WHAT") { continue; } - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } - write_len += len; - } - for (const auto &x : headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } + ssize_t write_len = 0; + for (const auto &x : info.headers) { + if (x.first == "EXCEPTION_WHAT") { + continue; + } + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } write_len += len; - } - auto len = strm.write("\r\n"); - if (len < 0) { return len; } - write_len += len; - return write_len; + return write_len; } inline bool write_data(Stream &strm, const char *d, size_t l) { - size_t offset = 0; - while (offset < l) { - auto length = strm.write(d + offset, l - offset); - if (length < 0) { return false; } - offset += static_cast(length); - } - return true; + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { + return false; + } + offset += static_cast(length); + } + return true; } -template +template inline ssize_t write_content(Stream &strm, ContentProvider content_provider, size_t offset, size_t length, T is_shutting_down) { - size_t begin_offset = offset; - size_t end_offset = offset + length; - auto ok = true; - DataSink data_sink; + size_t begin_offset = offset; + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - offset += l; - if (!write_data(strm, d, l)) { ok = false; } - } - }; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { + ok = false; + } + } + }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (offset < end_offset && !is_shutting_down()) { - if (!content_provider(offset, end_offset - offset, data_sink)) { - return -1; + while (offset < end_offset && !is_shutting_down()) { + if (!content_provider(offset, end_offset - offset, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } } - if (!ok) { return -1; } - } - return static_cast(offset - begin_offset); + return static_cast(offset - begin_offset); } -template +template inline ssize_t write_content_without_length(Stream &strm, ContentProvider content_provider, T is_shutting_down) { - size_t offset = 0; - auto data_available = true; - auto ok = true; - DataSink data_sink; + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - offset += l; - if (!write_data(strm, d, l)) { ok = false; } - } - }; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { + ok = false; + } + } + }; - data_sink.done = [&](void) { data_available = false; }; + data_sink.done = [&](void) { data_available = false; }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (data_available && !is_shutting_down()) { - if (!content_provider(offset, 0, data_sink)) { return -1; } - if (!ok) { return -1; } - } + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } + } - return static_cast(offset); + return static_cast(offset); } -template +template inline ssize_t write_content_chunked(Stream &strm, ContentProvider content_provider, T is_shutting_down, U &compressor) { - size_t offset = 0; - auto data_available = true; - ssize_t total_written_length = 0; - auto ok = true; - DataSink data_sink; - - data_sink.write = [&](const char *d, size_t l) { - if (!ok) { return; } - - data_available = l > 0; - offset += l; - - std::string payload; - if (!compressor.compress(d, l, false, - [&](const char *data, size_t data_len) { - payload.append(data, data_len); - return true; - })) { - ok = false; - return; - } - - if (!payload.empty()) { - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (write_data(strm, chunk.data(), chunk.size())) { - total_written_length += chunk.size(); - } else { - ok = false; - return; - } - } - }; - - data_sink.done = [&](void) { - if (!ok) { return; } - - data_available = false; - - std::string payload; - if (!compressor.compress(nullptr, 0, true, - [&](const char *data, size_t data_len) { - payload.append(data, data_len); - return true; - })) { - ok = false; - return; - } - - if (!payload.empty()) { - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (write_data(strm, chunk.data(), chunk.size())) { - total_written_length += chunk.size(); - } else { - ok = false; - return; - } - } - - static const std::string done_marker("0\r\n\r\n"); - if (write_data(strm, done_marker.data(), done_marker.size())) { - total_written_length += done_marker.size(); - } else { - ok = false; - } - }; + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + auto ok = true; + DataSink data_sink; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + data_sink.write = [&](const char *d, size_t l) { + if (!ok) { + return; + } + + data_available = l > 0; + offset += l; + + std::string payload; + if (!compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (write_data(strm, chunk.data(), chunk.size())) { + total_written_length += chunk.size(); + } else { + ok = false; + return; + } + } + }; - while (data_available && !is_shutting_down()) { - if (!content_provider(offset, 0, data_sink)) { return -1; } - if (!ok) { return -1; } - } + data_sink.done = [&](void) { + if (!ok) { + return; + } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (write_data(strm, chunk.data(), chunk.size())) { + total_written_length += chunk.size(); + } else { + ok = false; + return; + } + } - return total_written_length; + static const std::string done_marker("0\r\n\r\n"); + if (write_data(strm, done_marker.data(), done_marker.size())) { + total_written_length += done_marker.size(); + } else { + ok = false; + } + }; + + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { + return -1; + } + if (!ok) { + return -1; + } + } + + return total_written_length; } -template +template inline bool redirect(T &cli, const Request &req, Response &res, const std::string &path) { - Request new_req = req; - new_req.path = path; - new_req.redirect_count -= 1; - - if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { - new_req.method = "GET"; - new_req.body.clear(); - new_req.headers.clear(); - } + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } - Response new_res; + Response new_res; - auto ret = cli.send(new_req, new_res); - if (ret) { res = new_res; } - return ret; + auto ret = cli.send(new_req, new_res); + if (ret) { + res = new_res; + } + return ret; } inline std::string params_to_query_str(const Params ¶ms) { - std::string query; + std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } - query += it->first; - query += "="; - query += encode_url(it->second); - } - return query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { + query += "&"; + } + query += it->first; + query += "="; + query += encode_url(it->second); + } + return query; } inline void parse_query_text(const std::string &s, Params ¶ms) { - split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char *b2, const char *e2) { - if (key.empty()) { - key.assign(b2, e2); - } else { - val.assign(b2, e2); - } - }); + split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); - if (!key.empty()) { - params.emplace(decode_url(key, true), decode_url(val, true)); - } - }); + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); } inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { return false; } - boundary = content_type.substr(pos + 9); - if (boundary.length() >= 2 && boundary.front() == '"' && - boundary.back() == '"') { - boundary = boundary.substr(1, boundary.size() - 2); - } - return !boundary.empty(); + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { + return false; + } + boundary = content_type.substr(pos + 9); + if (boundary.length() >= 2 && boundary.front() == '"' && + boundary.back() == '"') { + boundary = boundary.substr(1, boundary.size() - 2); + } + return !boundary.empty(); } inline bool parse_range_header(const std::string &s, Ranges &ranges) { - static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); - std::smatch m; - if (std::regex_match(s, m, re_first_range)) { - auto pos = static_cast(m.position(1)); - auto len = static_cast(m.length(1)); - bool all_valid_ranges = true; - split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { - if (!all_valid_ranges) return; - static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch cm; - if (std::regex_match(b, e, cm, re_another_range)) { - ssize_t first = -1; - if (!cm.str(1).empty()) { - first = static_cast(std::stoll(cm.str(1))); - } - - ssize_t last = -1; - if (!cm.str(2).empty()) { - last = static_cast(std::stoll(cm.str(2))); - } - - if (first != -1 && last != -1 && first > last) { - all_valid_ranges = false; - return; - } - ranges.emplace_back(std::make_pair(first, last)); - } - }); - return all_valid_ranges; - } - return false; + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; } class MultipartFormDataParser { public: - MultipartFormDataParser() = default; - - void set_boundary(std::string &&boundary) { boundary_ = boundary; } - - bool is_valid() const { return is_valid_; } - - template - bool parse(const char *buf, size_t n, T content_callback, U header_callback) { - - static const std::regex re_content_disposition( - "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" - "\"(.*?)\")?\\s*$", - std::regex_constants::icase); - static const std::string dash_ = "--"; - static const std::string crlf_ = "\r\n"; - - buf_.append(buf, n); // TODO: performance improvement - - while (!buf_.empty()) { - switch (state_) { - case 0: { // Initial boundary - auto pattern = dash_ + boundary_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - auto pos = buf_.find(pattern); - if (pos != 0) { return false; } - buf_.erase(0, pattern.size()); - off_ += pattern.size(); - state_ = 1; - break; - } - case 1: { // New entry - clear_file_info(); - state_ = 2; - break; - } - case 2: { // Headers - auto pos = buf_.find(crlf_); - while (pos != std::string::npos) { - // Empty line - if (pos == 0) { - if (!header_callback(file_)) { - is_valid_ = false; - return false; - } - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); - state_ = 3; - break; - } - - static const std::string header_name = "content-type:"; - const auto header = buf_.substr(0, pos); - if (start_with(header, header_name)) { - file_.content_type = trim_copy(header.substr(header_name.size())); - } else { - std::smatch m; - if (std::regex_match(header, m, re_content_disposition)) { - file_.name = m[1]; - file_.filename = m[2]; - } - } - - buf_.erase(0, pos + crlf_.size()); - off_ += pos + crlf_.size(); - pos = buf_.find(crlf_); - } - if (state_ != 3) { return true; } - break; - } - case 3: { // Body - { - auto pattern = crlf_ + dash_; - if (pattern.size() > buf_.size()) { return true; } - - auto pos = buf_.find(pattern); - if (pos == std::string::npos) { - pos = buf_.size(); - while (pos > 0) { - auto c = buf_[pos - 1]; - if (c != '\r' && c != '\n' && c != '-') { break; } - pos--; - } - } + MultipartFormDataParser() = default; - if (!content_callback(buf_.data(), pos)) { - is_valid_ = false; - return false; - } + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + } - off_ += pos; - buf_.erase(0, pos); - } + bool is_valid() const { + return is_valid_; + } - { - auto pattern = crlf_ + dash_ + boundary_; - if (pattern.size() > buf_.size()) { return true; } - - auto pos = buf_.find(pattern); - if (pos != std::string::npos) { - if (!content_callback(buf_.data(), pos)) { - is_valid_ = false; - return false; + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + static const std::string dash_ = "--"; + static const std::string crlf_ = "\r\n"; + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { + return true; + } + auto pos = buf_.find(pattern); + if (pos != 0) { + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + static const std::string header_name = "content-type:"; + const auto header = buf_.substr(0, pos); + if (start_with(header, header_name)) { + file_.content_type = trim_copy(header.substr(header_name.size())); + } else { + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + if (state_ != 3) { + return true; + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { + return true; + } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { + pos = buf_.size(); + while (pos > 0) { + auto c = buf_[pos - 1]; + if (c != '\r' && c != '\n' && c != '-') { + break; + } + pos--; + } + } + + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { + return true; + } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { + return true; + } + if (buf_.compare(0, crlf_.size(), crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { + return true; + } + if (buf_.compare(0, pattern.size(), pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; } - - off_ += pos + pattern.size(); - buf_.erase(0, pos + pattern.size()); - state_ = 4; - } else { - if (!content_callback(buf_.data(), pattern.size())) { - is_valid_ = false; - return false; } - - off_ += pattern.size(); - buf_.erase(0, pattern.size()); - } - } - break; - } - case 4: { // Boundary - if (crlf_.size() > buf_.size()) { return true; } - if (buf_.compare(0, crlf_.size(), crlf_) == 0) { - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); - state_ = 1; - } else { - auto pattern = dash_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - if (buf_.compare(0, pattern.size(), pattern) == 0) { - buf_.erase(0, pattern.size()); - off_ += pattern.size(); - is_valid_ = true; - state_ = 5; - } else { - return true; - } } - break; - } - case 5: { // Done - is_valid_ = false; - return false; - } - } - } - return true; - } + return true; + } private: - void clear_file_info() { - file_.name.clear(); - file_.filename.clear(); - file_.content_type.clear(); - } - - std::string boundary_; - - std::string buf_; - size_t state_ = 0; - bool is_valid_ = false; - size_t off_ = 0; - MultipartFormData file_; + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + bool is_valid_ = false; + size_t off_ = 0; + MultipartFormData file_; }; inline std::string to_lower(const char *beg, const char *end) { - std::string out; - auto it = beg; - while (it != end) { - out += static_cast(::tolower(*it)); - it++; - } - return out; + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; } inline std::string make_multipart_data_boundary() { - static const char data[] = - "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - std::random_device seed_gen; - std::mt19937 engine(seed_gen()); + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); - std::string result = "--cpp-httplib-multipart-data-"; + std::string result = "--cpp-httplib-multipart-data-"; - for (auto i = 0; i < 16; i++) { - result += data[engine() % (sizeof(data) - 1)]; - } + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } - return result; + return result; } inline std::pair get_range_offset_and_length(const Request &req, size_t content_length, size_t index) { - auto r = req.ranges[index]; + auto r = req.ranges[index]; - if (r.first == -1 && r.second == -1) { - return std::make_pair(0, content_length); - } + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } - auto slen = static_cast(content_length); + auto slen = static_cast(content_length); - if (r.first == -1) { - r.first = slen - r.second; - r.second = slen - 1; - } + if (r.first == -1) { + r.first = slen - r.second; + r.second = slen - 1; + } - if (r.second == -1) { r.second = slen - 1; } + if (r.second == -1) { + r.second = slen - 1; + } - return std::make_pair(r.first, r.second - r.first + 1); + return std::make_pair(r.first, r.second - r.first + 1); } inline std::string make_content_range_header_field(size_t offset, size_t length, size_t content_length) { - std::string field = "bytes "; - field += std::to_string(offset); - field += "-"; - field += std::to_string(offset + length - 1); - field += "/"; - field += std::to_string(content_length); - return field; + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; } -template +template bool process_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, const std::string &content_type, SToken stoken, CToken ctoken, Content content) { - for (size_t i = 0; i < req.ranges.size(); i++) { - ctoken("--"); - stoken(boundary); - ctoken("\r\n"); - if (!content_type.empty()) { - ctoken("Content-Type: "); - stoken(content_type); - ctoken("\r\n"); - } + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } - auto offsets = get_range_offset_and_length(req, res.body.size(), i); - auto offset = offsets.first; - auto length = offsets.second; + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; - ctoken("Content-Range: "); - stoken(make_content_range_header_field(offset, length, res.body.size())); - ctoken("\r\n"); - ctoken("\r\n"); - if (!content(offset, length)) { return false; } - ctoken("\r\n"); - } + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { + return false; + } + ctoken("\r\n"); + } - ctoken("--"); - stoken(boundary); - ctoken("--\r\n"); + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); - return true; + return true; } inline std::string make_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - std::string data; - - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data += token; }, - [&](const char *token) { data += token; }, - [&](size_t offset, size_t length) { - data += res.body.substr(offset, length); - return true; - }); + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); - return data; + return data; } inline size_t get_multipart_ranges_data_length(const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - size_t data_length = 0; - - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data_length += token.size(); }, - [&](const char *token) { data_length += strlen(token); }, - [&](size_t /*offset*/, size_t length) { - data_length += length; - return true; - }); + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); - return data_length; + return data_length; } -template +template inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type, T is_shutting_down) { - return process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { strm.write(token); }, - [&](const char *token) { strm.write(token); }, - [&](size_t offset, size_t length) { - return write_content(strm, res.content_provider_, offset, length, - is_shutting_down) >= 0; - }); + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down) >= 0; + }); } inline std::pair get_range_offset_and_length(const Request &req, const Response &res, size_t index) { - auto r = req.ranges[index]; + auto r = req.ranges[index]; - if (r.second == -1) { - r.second = static_cast(res.content_length_) - 1; - } + if (r.second == -1) { + r.second = static_cast(res.content_length_) - 1; + } - return std::make_pair(r.first, r.second - r.first + 1); + return std::make_pair(r.first, r.second - r.first + 1); } inline bool expect_content(const Request &req) { - if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || - req.method == "PRI" || req.method == "DELETE") { - return true; - } - // TODO: check if Content-Length is set - return false; + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; } inline bool has_crlf(const char *s) { - auto p = s; - while (*p) { - if (*p == '\r' || *p == '\n') { return true; } - p++; - } - return false; + auto p = s; + while (*p) { + if (*p == '\r' || *p == '\n') { + return true; + } + p++; + } + return false; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -template +template inline std::string message_digest(const std::string &s, Init init, Update update, Final final, size_t digest_length) { - using namespace std; + using namespace std; - std::vector md(digest_length, 0); - CTX ctx; - init(&ctx); - update(&ctx, s.data(), s.size()); - final(md.data(), &ctx); + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); - stringstream ss; - for (auto c : md) { - ss << setfill('0') << setw(2) << hex << (unsigned int)c; - } - return ss.str(); + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); } inline std::string MD5(const std::string &s) { - return message_digest(s, MD5_Init, MD5_Update, MD5_Final, - MD5_DIGEST_LENGTH); + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); } inline std::string SHA_256(const std::string &s) { - return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, - SHA256_DIGEST_LENGTH); + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); } inline std::string SHA_512(const std::string &s) { - return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, - SHA512_DIGEST_LENGTH); + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); } #endif @@ -3310,37 +3590,41 @@ inline std::string SHA_512(const std::string &s) { // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store inline bool load_system_certs_on_windows(X509_STORE *store) { - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - if (!hStore) { return false; } + if (!hStore) { + return false; + } - PCCERT_CONTEXT pContext = NULL; - while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); + PCCERT_CONTEXT pContext = NULL; + while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + } } - } - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); - return true; + return true; } #endif class WSInit { public: - WSInit() { - WSADATA wsaData; - WSAStartup(0x0002, &wsaData); - } + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } - ~WSInit() { WSACleanup(); } + ~WSInit() { + WSACleanup(); + } }; static WSInit wsinit_; @@ -3351,340 +3635,355 @@ inline std::pair make_digest_authentication_header( const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, const std::string &username, const std::string &password, bool is_proxy = false) { - using namespace std; + using namespace std; - string nc; - { - stringstream ss; - ss << setfill('0') << setw(8) << hex << cnonce_count; - nc = ss.str(); - } + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } - auto qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else { - qop = "auth"; - } + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { + algo = auth.at("algorithm"); + } - string response; - { - auto H = algo == "SHA-256" - ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : + detail::MD5; - auto A1 = username + ":" + auth.at("realm") + ":" + password; + auto A1 = username + ":" + auth.at("realm") + ":" + password; - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + - "\", response=\"" + response + "\""; + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + + "\", response=\"" + response + "\""; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } #endif inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { - auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; - if (res.has_header(auth_key)) { - static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); - auto s = res.get_header_value(auth_key); - auto pos = s.find(' '); - if (pos != std::string::npos) { - auto type = s.substr(0, pos); - if (type == "Basic") { - return false; - } else if (type == "Digest") { - s = s.substr(pos + 1); - auto beg = std::sregex_iterator(s.begin(), s.end(), re); - for (auto i = beg; i != std::sregex_iterator(); ++i) { - auto m = *i; - auto key = s.substr(static_cast(m.position(1)), - static_cast(m.length(1))); - auto val = m.length(2) > 0 - ? s.substr(static_cast(m.position(2)), - static_cast(m.length(2))) - : s.substr(static_cast(m.position(3)), - static_cast(m.length(3))); - auth[key] = val; + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) : + s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } } - return true; - } } - } - return false; + return false; } // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 inline std::string random_string(size_t length) { - auto randchar = []() -> char { - const char charset[] = "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = (sizeof(charset) - 1); - return charset[static_cast(rand()) % max_index]; - }; - std::string str(length, 0); - std::generate_n(str.begin(), length, randchar); - return str; + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; } class ContentProviderAdapter { public: - explicit ContentProviderAdapter( - ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) { + } - bool operator()(size_t offset, size_t, DataSink &sink) { - return content_provider_(offset, sink); - } + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } private: - ContentProviderWithoutLength content_provider_; + ContentProviderWithoutLength content_provider_; }; -} // namespace detail +} // namespace detail // Header utilities inline std::pair make_range_header(Ranges ranges) { - std::string field = "bytes="; - auto i = 0; - for (auto r : ranges) { - if (i != 0) { field += ", "; } - if (r.first != -1) { field += std::to_string(r.first); } - field += '-'; - if (r.second != -1) { field += std::to_string(r.second); } - i++; - } - return std::make_pair("Range", field); + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { + field += ", "; + } + if (r.first != -1) { + field += std::to_string(r.first); + } + field += '-'; + if (r.second != -1) { + field += std::to_string(r.second); + } + i++; + } + return std::make_pair("Range", field); } inline std::pair make_basic_authentication_header(const std::string &username, const std::string &password, bool is_proxy = false) { - auto field = "Basic " + detail::base64_encode(username + ":" + password); - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } inline std::pair make_bearer_token_authentication_header(const std::string &token, bool is_proxy = false) { - auto field = "Bearer " + token; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } // Request implementation inline bool Request::has_header(const char *key) const { - return detail::has_header(headers, key); + return detail::has_header(headers, key); } inline std::string Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, id, ""); } -template +template inline T Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, 0); + return detail::get_header_value(headers, key, id, 0); } inline size_t Request::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline void Request::set_header(const char *key, const char *val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } } inline void Request::set_header(const char *key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } } inline bool Request::has_param(const char *key) const { - return params.find(key) != params.end(); + return params.find(key) != params.end(); } inline std::string Request::get_param_value(const char *key, size_t id) const { - auto rng = params.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second; + } + return std::string(); } inline size_t Request::get_param_value_count(const char *key) const { - auto r = params.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline bool Request::is_multipart_form_data() const { - const auto &content_type = get_header_value("Content-Type"); - return !content_type.find("multipart/form-data"); + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); } inline bool Request::has_file(const char *key) const { - return files.find(key) != files.end(); + return files.find(key) != files.end(); } inline MultipartFormData Request::get_file_value(const char *key) const { - auto it = files.find(key); - if (it != files.end()) { return it->second; } - return MultipartFormData(); + auto it = files.find(key); + if (it != files.end()) { + return it->second; + } + return MultipartFormData(); } // Response implementation inline bool Response::has_header(const char *key) const { - return headers.find(key) != headers.end(); + return headers.find(key) != headers.end(); } inline std::string Response::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, id, ""); } -template +template inline T Response::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, 0); + return detail::get_header_value(headers, key, id, 0); } inline size_t Response::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); } inline void Response::set_header(const char *key, const char *val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } } inline void Response::set_header(const char *key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { - headers.emplace(key, val); - } + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } } inline void Response::set_redirect(const char *url, int stat) { - if (!detail::has_crlf(url)) { - set_header("Location", url); - if (300 <= stat && stat < 400) { - this->status = stat; - } else { - this->status = 302; + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = 302; + } } - } } inline void Response::set_redirect(const std::string &url, int stat) { - set_redirect(url.c_str(), stat); + set_redirect(url.c_str(), stat); } inline void Response::set_content(const char *s, size_t n, const char *content_type) { - body.assign(s, n); - set_header("Content-Type", content_type); + body.assign(s, n); + set_header("Content-Type", content_type); } inline void Response::set_content(std::string s, const char *content_type) { - body = std::move(s); - set_header("Content-Type", content_type); + body = std::move(s); + set_header("Content-Type", content_type); } inline void Response::set_content_provider(size_t in_length, const char *content_type, ContentProvider provider, const std::function &resource_releaser) { - assert(in_length > 0); - set_header("Content-Type", content_type); - content_length_ = in_length; - content_provider_ = std::move(provider); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = false; + assert(in_length > 0); + set_header("Content-Type", content_type); + content_length_ = in_length; + content_provider_ = std::move(provider); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } inline void Response::set_content_provider(const char *content_type, ContentProviderWithoutLength provider, const std::function &resource_releaser) { - set_header("Content-Type", content_type); - content_length_ = 0; - content_provider_ = detail::ContentProviderAdapter(std::move(provider)); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = false; + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } inline void Response::set_chunked_content_provider( const char *content_type, ContentProviderWithoutLength provider, const std::function &resource_releaser) { - set_header("Content-Type", content_type); - content_length_ = 0; - content_provider_ = detail::ContentProviderAdapter(std::move(provider)); - content_provider_resource_releaser_ = resource_releaser; - is_chunked_content_provider = true; + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = true; } // Rstream implementation inline ssize_t Stream::write(const char *ptr) { - return write(ptr, strlen(ptr)); + return write(ptr, strlen(ptr)); } inline ssize_t Stream::write(const std::string &s) { - return write(s.data(), s.size()); + return write(s.data(), s.size()); } -template -inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { - const auto bufsiz = 2048; - std::array buf; +template +inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { + const auto bufsiz = 2048; + std::array buf; #if defined(_MSC_VER) && _MSC_VER < 1900 - auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); + auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); #else - auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); #endif - if (sn <= 0) { return sn; } + if (sn <= 0) { + return sn; + } - auto n = static_cast(sn); + auto n = static_cast(sn); - if (n >= buf.size() - 1) { - std::vector glowable_buf(buf.size()); + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); - while (n >= glowable_buf.size() - 1) { - glowable_buf.resize(glowable_buf.size() * 2); + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); #if defined(_MSC_VER) && _MSC_VER < 1900 - n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), - glowable_buf.size() - 1, fmt, - args...)); + n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, + args...)); #else - n = static_cast( - snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); #endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); } - return write(&glowable_buf[0], n); - } else { - return write(buf.data(), n); - } } namespace detail { @@ -3697,75 +3996,88 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, : sock_(sock), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec) {} + write_timeout_usec_(write_timeout_usec) { +} -inline SocketStream::~SocketStream() {} +inline SocketStream::~SocketStream() { +} inline bool SocketStream::is_readable() const { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SocketStream::is_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; } inline ssize_t SocketStream::read(char *ptr, size_t size) { - if (!is_readable()) { return -1; } + if (!is_readable()) { + return -1; + } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return recv(sock_, ptr, static_cast(size), 0); + if (size > static_cast((std::numeric_limits::max)())) { + return -1; + } + return recv(sock_, ptr, static_cast(size), 0); #else - return handle_EINTR([&]() { return recv(sock_, ptr, size, 0); }); + return handle_EINTR([&]() { return recv(sock_, ptr, size, 0); }); #endif } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!is_writable()) { + return -1; + } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return send(sock_, ptr, static_cast(size), 0); + if (size > static_cast((std::numeric_limits::max)())) { + return -1; + } + return send(sock_, ptr, static_cast(size), 0); #else - return handle_EINTR([&]() { return send(sock_, ptr, size, 0); }); + return handle_EINTR([&]() { return send(sock_, ptr, size, 0); }); #endif } inline void SocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { - return detail::get_remote_ip_and_port(sock_, ip, port); + return detail::get_remote_ip_and_port(sock_, ip, port); } // Buffer stream implementation -inline bool BufferStream::is_readable() const { return true; } +inline bool BufferStream::is_readable() const { + return true; +} -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::is_writable() const { + return true; +} inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER <= 1900 - auto len_read = buffer._Copy_s(ptr, size, size, position); + auto len_read = buffer._Copy_s(ptr, size, size, position); #else - auto len_read = buffer.copy(ptr, size, position); + auto len_read = buffer.copy(ptr, size, position); #endif - position += static_cast(len_read); - return static_cast(len_read); + position += static_cast(len_read); + return static_cast(len_read); } inline ssize_t BufferStream::write(const char *ptr, size_t size) { - buffer.append(ptr, size); - return static_cast(size); + buffer.append(ptr, size); + return static_cast(size); } inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, - int & /*port*/) const {} + int & /*port*/) const { +} -inline const std::string &BufferStream::get_buffer() const { return buffer; } +inline const std::string &BufferStream::get_buffer() const { + return buffer; +} -} // namespace detail +} // namespace detail // HTTP server implementation inline Server::Server() @@ -3773,1182 +4085,1279 @@ inline Server::Server() [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }), svr_sock_(INVALID_SOCKET), is_running_(false) { #ifdef __linux__ - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif } -inline Server::~Server() {} +inline Server::~Server() { +} inline Server &Server::Get(const char *pattern, Handler handler) { - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Post(const char *pattern, Handler handler) { - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Post(const char *pattern, HandlerWithContentReader handler) { - post_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Put(const char *pattern, Handler handler) { - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Put(const char *pattern, HandlerWithContentReader handler) { - put_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Patch(const char *pattern, Handler handler) { - patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Patch(const char *pattern, HandlerWithContentReader handler) { - patch_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Delete(const char *pattern, Handler handler) { - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Delete(const char *pattern, HandlerWithContentReader handler) { - delete_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); - return *this; + delete_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } inline Server &Server::Options(const char *pattern, Handler handler) { - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } inline bool Server::set_base_dir(const char *dir, const char *mount_point) { - return set_mount_point(mount_point, dir); + return set_mount_point(mount_point, dir); } inline bool Server::set_mount_point(const char *mount_point, const char *dir) { - if (detail::is_dir(dir)) { - std::string mnt = mount_point ? mount_point : "/"; - if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.emplace_back(mnt, dir); - return true; + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } } - } - return false; + return false; } inline bool Server::remove_mount_point(const char *mount_point) { - for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { - if (it->first == mount_point) { - base_dirs_.erase(it); - return true; + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } } - } - return false; + return false; } inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, const char *mime) { - file_extension_and_mimetype_map_[ext] = mime; + file_extension_and_mimetype_map_[ext] = mime; } inline void Server::set_file_request_handler(Handler handler) { - file_request_handler_ = std::move(handler); + file_request_handler_ = std::move(handler); } inline void Server::set_error_handler(Handler handler) { - error_handler_ = std::move(handler); + error_handler_ = std::move(handler); } -inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} inline void Server::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = socket_options; } -inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } +inline void Server::set_logger(Logger logger) { + logger_ = std::move(logger); +} inline void Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { - expect_100_continue_handler_ = std::move(handler); + expect_100_continue_handler_ = std::move(handler); } inline void Server::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; + keep_alive_max_count_ = count; } inline void Server::set_keep_alive_timeout(time_t sec) { - keep_alive_timeout_sec_ = sec; + keep_alive_timeout_sec_ = sec; } inline void Server::set_read_timeout(time_t sec, time_t usec) { - read_timeout_sec_ = sec; - read_timeout_usec_ = usec; + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } inline void Server::set_write_timeout(time_t sec, time_t usec) { - write_timeout_sec_ = sec; - write_timeout_usec_ = usec; + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; } inline void Server::set_idle_interval(time_t sec, time_t usec) { - idle_interval_sec_ = sec; - idle_interval_usec_ = usec; + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; } inline void Server::set_payload_max_length(size_t length) { - payload_max_length_ = length; + payload_max_length_ = length; } inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { - if (bind_internal(host, port, socket_flags) < 0) return false; - return true; + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; } inline int Server::bind_to_any_port(const char *host, int socket_flags) { - return bind_internal(host, 0, socket_flags); + return bind_internal(host, 0, socket_flags); } -inline bool Server::listen_after_bind() { return listen_internal(); } +inline bool Server::listen_after_bind() { + return listen_internal(); +} inline bool Server::listen(const char *host, int port, int socket_flags) { - return bind_to_port(host, port, socket_flags) && listen_internal(); + return bind_to_port(host, port, socket_flags) && listen_internal(); } -inline bool Server::is_running() const { return is_running_; } +inline bool Server::is_running() const { + return is_running_; +} inline void Server::stop() { - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); - detail::shutdown_socket(sock); - detail::close_socket(sock); - } + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } } inline bool Server::parse_request_line(const char *s, Request &req) { - const static std::regex re( - "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " - "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); - - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[5]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3], false); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { detail::parse_query_text(m[4], req.params); } + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3], false); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { + detail::parse_query_text(m[4], req.params); + } - return true; - } + return true; + } - return false; + return false; } inline bool Server::write_response(Stream &strm, bool close_connection, const Request &req, Response &res) { - assert(res.status != -1); + assert(res.status != -1); - if (400 <= res.status && error_handler_) { error_handler_(req, res); } - - detail::BufferStream bstrm; + if (400 <= res.status && error_handler_) { + error_handler_(req, res); + } - // Response line - if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, - detail::status_message(res.status))) { - return false; - } + detail::BufferStream bstrm; - // Headers - if (close_connection || req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); - } else { - std::stringstream ss; - ss << "timeout=" << keep_alive_timeout_sec_ - << ", max=" << keep_alive_max_count_; - res.set_header("Keep-Alive", ss.str()); - } + // Response line + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } - if (!res.has_header("Content-Type") && - (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { - res.set_header("Content-Type", "text/plain"); - } + // Headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::stringstream ss; + ss << "timeout=" << keep_alive_timeout_sec_ + << ", max=" << keep_alive_max_count_; + res.set_header("Keep-Alive", ss.str()); + } - if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { - res.set_header("Accept-Ranges", "bytes"); - } + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + res.set_header("Content-Type", "text/plain"); + } - std::string content_type; - std::string boundary; + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } - if (req.ranges.size() > 1) { - boundary = detail::make_multipart_data_boundary(); + std::string content_type; + std::string boundary; - auto it = res.headers.find("Content-Type"); - if (it != res.headers.end()) { - content_type = it->second; - res.headers.erase(it); - } + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); - res.headers.emplace("Content-Type", - "multipart/byteranges; boundary=" + boundary); - } + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } - auto type = detail::encoding_type(req, res); + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } - if (res.body.empty()) { - if (res.content_length_ > 0) { - size_t length = 0; - if (req.ranges.empty()) { - length = res.content_length_; - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length_, 0); - auto offset = offsets.first; - length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.content_length_); - res.set_header("Content-Range", content_range); - } else { - length = detail::get_multipart_ranges_data_length(req, res, boundary, - content_type); - } - res.set_header("Content-Length", std::to_string(length)); - } else { - if (res.content_provider_) { - if (res.is_chunked_content_provider) { - res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); - } - } - } else { - res.set_header("Content-Length", "0"); - } - } - } else { - if (req.ranges.empty()) { - ; - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.body.size(), 0); - auto offset = offsets.first; - auto length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.body.size()); - res.set_header("Content-Range", content_range); - res.body = res.body.substr(offset, length); + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } else { + res.set_header("Content-Length", "0"); + } + } } else { - res.body = - detail::make_multipart_ranges_data(req, res, boundary, content_type); - } + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); + } - if (type != detail::EncodingType::None) { - std::shared_ptr compressor; + if (type != detail::EncodingType::None) { + std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); - res.set_header("Content-Encoding", "gzip"); + compressor = std::make_shared(); + res.set_header("Content-Encoding", "gzip"); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); - res.set_header("Content-Encoding", "brotli"); + compressor = std::make_shared(); + res.set_header("Content-Encoding", "brotli"); #endif - } + } + + if (compressor) { + std::string compressed; - if (compressor) { - std::string compressed; + if (!compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + return false; + } - if (!compressor->compress(res.body.data(), res.body.size(), true, - [&](const char *data, size_t data_len) { - compressed.append(data, data_len); - return true; - })) { - return false; + res.body.swap(compressed); + } } - res.body.swap(compressed); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length); - } - - if (!detail::write_headers(bstrm, res, Headers())) { return false; } + if (!detail::write_headers(bstrm, res, Headers())) { + return false; + } - // Flush buffer - auto &data = bstrm.get_buffer(); - strm.write(data.data(), data.size()); + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); - // Body - auto ret = true; - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!strm.write(res.body)) { ret = false; } - } else if (res.content_provider_) { - if (!write_content_with_provider(strm, req, res, boundary, - content_type)) { - ret = false; - } + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { + ret = false; + } + } else if (res.content_provider_) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + ret = false; + } + } } - } - // Log - if (logger_) { logger_(req, res); } + // Log + if (logger_) { + logger_(req, res); + } - return ret; + return ret; } inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { - auto is_shutting_down = [this]() { - return this->svr_sock_ == INVALID_SOCKET; - }; - - if (res.content_length_ > 0) { - if (req.ranges.empty()) { - if (detail::write_content(strm, res.content_provider_, 0, - res.content_length_, is_shutting_down) < 0) { - return false; - } - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length_, 0); - auto offset = offsets.first; - auto length = offsets.second; - if (detail::write_content(strm, res.content_provider_, offset, length, - is_shutting_down) < 0) { - return false; - } + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider_, offset, length, + is_shutting_down) < 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, is_shutting_down)) { + return false; + } + } } else { - if (!detail::write_multipart_ranges_data( - strm, req, res, boundary, content_type, is_shutting_down)) { - return false; - } - } - } else { - if (res.is_chunked_content_provider) { - auto type = detail::encoding_type(req, res); + if (res.is_chunked_content_provider) { + auto type = detail::encoding_type(req, res); - std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + std::shared_ptr compressor; + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else { - compressor = std::make_shared(); - } - assert(compressor != nullptr); + } else { + compressor = std::make_shared(); + } + assert(compressor != nullptr); - if (detail::write_content_chunked(strm, res.content_provider_, - is_shutting_down, *compressor) < 0) { - return false; - } - } else { - if (detail::write_content_without_length(strm, res.content_provider_, - is_shutting_down) < 0) { - return false; - } + if (detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor) < 0) { + return false; + } + } else { + if (detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down) < 0) { + return false; + } + } } - } - return true; + return true; } inline bool Server::read_content(Stream &strm, Request &req, Response &res) { - MultipartFormDataMap::iterator cur; - if (read_content_core( - strm, req, res, - // Regular - [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } - req.body.append(buf, n); - return true; - }, - // Multipart - [&](const MultipartFormData &file) { - cur = req.files.emplace(file.name, file); - return true; - }, - [&](const char *buf, size_t n) { - auto &content = cur->second.content; - if (content.size() + n > content.max_size()) { return false; } - content.append(buf, n); - return true; - })) { - const auto &content_type = req.get_header_value("Content-Type"); - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); + MultipartFormDataMap::iterator cur; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { + return false; + } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + return true; } - return true; - } - return false; + return false; } inline bool Server::read_content_with_content_receiver( Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader multipart_header, ContentReceiver multipart_receiver) { - return read_content_core(strm, req, res, receiver, multipart_header, - multipart_receiver); + return read_content_core(strm, req, res, receiver, multipart_header, + multipart_receiver); } inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader mulitpart_header, ContentReceiver multipart_receiver) { - detail::MultipartFormDataParser multipart_form_data_parser; - ContentReceiver out; + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return false; + } - if (req.is_multipart_form_data()) { - const auto &content_type = req.get_header_value("Content-Type"); - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary)) { - res.status = 400; - return false; - } - - multipart_form_data_parser.set_boundary(std::move(boundary)); - out = [&](const char *buf, size_t n) { - /* For debug - size_t pos = 0; - while (pos < n) { - auto read_size = std::min(1, n - pos); - auto ret = multipart_form_data_parser.parse( - buf + pos, read_size, multipart_receiver, mulitpart_header); - if (!ret) { return false; } - pos += read_size; - } - return true; - */ - return multipart_form_data_parser.parse(buf, n, multipart_receiver, - mulitpart_header); - }; - } else { - out = receiver; - } + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = std::min(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, mulitpart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { - return true; - } + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } - if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, - out, true)) { - return false; - } + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } - if (req.is_multipart_form_data()) { - if (!multipart_form_data_parser.is_valid()) { - res.status = 400; - return false; + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return false; + } } - } - return true; + return true; } inline bool Server::handle_file_request(Request &req, Response &res, bool head) { - for (const auto &kv : base_dirs_) { - const auto &mount_point = kv.first; - const auto &base_dir = kv.second; - - // Prefix match - if (!req.path.compare(0, mount_point.size(), mount_point)) { - std::string sub_path = "/" + req.path.substr(mount_point.size()); - if (detail::is_valid_path(sub_path)) { - auto path = base_dir + sub_path; - if (path.back() == '/') { path += "index.html"; } - - if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = - detail::find_content_type(path, file_extension_and_mimetype_map_); - if (type) { res.set_header("Content-Type", type); } - res.status = 200; - if (!head && file_request_handler_) { - file_request_handler_(req, res); - } - return true; - } - } - } - } - return false; + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.compare(0, mount_point.size(), mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { + path += "index.html"; + } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { + res.set_header("Content-Type", type); + } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; } inline socket_t Server::create_server_socket(const char *host, int port, int socket_flags, SocketOptions socket_options) const { - return detail::create_socket( - host, port, socket_flags, tcp_nodelay_, socket_options, - [](socket_t sock, struct addrinfo &ai) -> bool { - if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { - return false; - } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }); + return detail::create_socket( + host, port, socket_flags, tcp_nodelay_, socket_options, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }); } inline int Server::bind_internal(const char *host, int port, int socket_flags) { - if (!is_valid()) { return -1; } - - svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); - if (svr_sock_ == INVALID_SOCKET) { return -1; } + if (!is_valid()) { + return -1; + } - if (port == 0) { - struct sockaddr_storage addr; - socklen_t addr_len = sizeof(addr); - if (getsockname(svr_sock_, reinterpret_cast(&addr), - &addr_len) == -1) { - return -1; + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { + return -1; } - if (addr.ss_family == AF_INET) { - return ntohs(reinterpret_cast(&addr)->sin_port); - } else if (addr.ss_family == AF_INET6) { - return ntohs(reinterpret_cast(&addr)->sin6_port); + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } } else { - return -1; + return port; } - } else { - return port; - } } inline bool Server::listen_internal() { - auto ret = true; - is_running_ = true; + auto ret = true; + is_running_ = true; - { - std::unique_ptr task_queue(new_task_queue()); + { + std::unique_ptr task_queue(new_task_queue()); - while (svr_sock_ != INVALID_SOCKET) { + while (svr_sock_ != INVALID_SOCKET) { #ifdef __linux__ - if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { #endif - auto val = detail::select_read(svr_sock_, idle_interval_sec_, - idle_interval_usec_); - if (val == 0) { // Timeout - task_queue->on_idle(); - continue; - } + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } #ifdef __linux__ - } + } #endif - socket_t sock = accept(svr_sock_, nullptr, nullptr); - - if (sock == INVALID_SOCKET) { - if (errno == EMFILE) { - // The per-process limit of open file descriptors has been reached. - // Try to accept new connections after a short sleep. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. - } - break; - } + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } #if __cplusplus > 201703L - task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); + task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); #else - task_queue->enqueue([=]() { process_and_close_socket(sock); }); + task_queue->enqueue([=]() { process_and_close_socket(sock); }); #endif - } + } - task_queue->shutdown(); - } + task_queue->shutdown(); + } - is_running_ = false; - return ret; + is_running_ = false; + return ret; } inline bool Server::routing(Request &req, Response &res, Stream &strm) { - // File handler - bool is_head_request = req.method == "HEAD"; - if ((req.method == "GET" || is_head_request) && - handle_file_request(req, res, is_head_request)) { - return true; - } + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } - if (detail::expect_content(req)) { - // Content reader handler - { - ContentReader reader( - [&](ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, receiver, - nullptr, nullptr); - }, - [&](MultipartContentHeader header, ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, nullptr, - header, receiver); - }); - - if (req.method == "POST") { - if (dispatch_request_for_content_reader( - req, res, reader, post_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "PUT") { - if (dispatch_request_for_content_reader( - req, res, reader, put_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "PATCH") { - if (dispatch_request_for_content_reader( - req, res, reader, patch_handlers_for_content_reader_)) { - return true; - } - } else if (req.method == "DELETE") { - if (dispatch_request_for_content_reader( - req, res, reader, delete_handlers_for_content_reader_)) { - return true; - } - } - } - - // Read content into `req.body` - if (!read_content(strm, req, res)) { return false; } - } - - // Regular handler - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } else if (req.method == "PATCH") { - return dispatch_request(req, res, patch_handlers_); - } - - res.status = 400; - return false; + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, receiver, + nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, reader, delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { + return false; + } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; } inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) { - try { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; + try { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; - } + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + } catch (const std::exception &ex) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", ex.what()); + } catch (...) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); } - } catch (const std::exception &ex) { - res.status = 500; - res.set_header("EXCEPTION_WHAT", ex.what()); - } catch (...) { - res.status = 500; - res.set_header("EXCEPTION_WHAT", "UNKNOWN"); - } - return false; + return false; } inline bool Server::dispatch_request_for_content_reader( Request &req, Response &res, ContentReader content_reader, const HandlersForContentReader &handlers) { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res, content_reader); - return true; + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } } - } - return false; + return false; } inline bool Server::process_request(Stream &strm, bool close_connection, bool &connection_closed, const std::function &setup_request) { - std::array buf{}; + std::array buf{}; - detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - // Connection has been closed on client - if (!line_reader.getline()) { return false; } + // Connection has been closed on client + if (!line_reader.getline()) { + return false; + } - Request req; - Response res; + Request req; + Response res; - res.version = "HTTP/1.1"; + res.version = "HTTP/1.1"; - // Check if the request URI doesn't exceed the limit - if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = 414; - return write_response(strm, close_connection, req, res); - } + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, close_connection, req, res); + } - // Request line and headers - if (!parse_request_line(line_reader.ptr(), req) || - !detail::read_headers(strm, req.headers)) { - res.status = 400; - return write_response(strm, close_connection, req, res); - } + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, close_connection, req, res); + } - if (req.get_header_value("Connection") == "close") { - connection_closed = true; - } + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } - if (req.version == "HTTP/1.0" && - req.get_header_value("Connection") != "Keep-Alive") { - connection_closed = true; - } + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } - strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); - req.set_header("REMOTE_ADDR", req.remote_addr); - req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); - if (req.has_header("Range")) { - const auto &range_header_value = req.get_header_value("Range"); - if (!detail::parse_range_header(range_header_value, req.ranges)) { - // TODO: error + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } } - } - - if (setup_request) { setup_request(req); } - if (req.get_header_value("Expect") == "100-continue") { - auto status = 100; - if (expect_100_continue_handler_) { - status = expect_100_continue_handler_(req, res); + if (setup_request) { + setup_request(req); } - switch (status) { - case 100: - case 417: - strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, - detail::status_message(status)); - break; - default: return write_response(strm, close_connection, req, res); + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + detail::status_message(status)); + break; + default: + return write_response(strm, close_connection, req, res); + } } - } - // Rounting - if (routing(req, res, strm)) { - if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } - } else { - if (res.status == -1) { res.status = 404; } - } + // Rounting + if (routing(req, res, strm)) { + if (res.status == -1) { + res.status = req.ranges.empty() ? 200 : 206; + } + } else { + if (res.status == -1) { + res.status = 404; + } + } - return write_response(strm, close_connection, req, res); + return write_response(strm, close_connection, req, res); } -inline bool Server::is_valid() const { return true; } +inline bool Server::is_valid() const { + return true; +} inline bool Server::process_and_close_socket(socket_t sock) { - auto ret = detail::process_server_socket( - sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, - nullptr); - }); + auto ret = detail::process_server_socket( + sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, + [this](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + nullptr); + }); - detail::shutdown_socket(sock); - detail::close_socket(sock); - return ret; + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; } // HTTP client implementation inline ClientImpl::ClientImpl(const std::string &host) - : ClientImpl(host, 80, std::string(), std::string()) {} + : ClientImpl(host, 80, std::string(), std::string()) { +} inline ClientImpl::ClientImpl(const std::string &host, int port) - : ClientImpl(host, port, std::string(), std::string()) {} + : ClientImpl(host, port, std::string(), std::string()) { +} inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(host), port_(port), host_and_port_(host_ + ":" + std::to_string(port_)), - client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + client_cert_path_(client_cert_path), client_key_path_(client_key_path) { +} -inline ClientImpl::~ClientImpl() { stop_core(); } +inline ClientImpl::~ClientImpl() { + stop_core(); +} -inline bool ClientImpl::is_valid() const { return true; } +inline bool ClientImpl::is_valid() const { + return true; +} -inline Error ClientImpl::get_last_error() const { return error_; } +inline Error ClientImpl::get_last_error() const { + return error_; +} inline socket_t ClientImpl::create_client_socket() const { - if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_.c_str(), proxy_port_, tcp_nodelay_, socket_options_, + connection_timeout_sec_, connection_timeout_usec_, interface_, error_); + } return detail::create_client_socket( - proxy_host_.c_str(), proxy_port_, tcp_nodelay_, socket_options_, + host_.c_str(), port_, tcp_nodelay_, socket_options_, connection_timeout_sec_, connection_timeout_usec_, interface_, error_); - } - return detail::create_client_socket( - host_.c_str(), port_, tcp_nodelay_, socket_options_, - connection_timeout_sec_, connection_timeout_usec_, interface_, error_); } inline bool ClientImpl::create_and_connect_socket(Socket &socket) { - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } - socket.sock = sock; - return true; + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { + return false; + } + socket.sock = sock; + return true; } inline void ClientImpl::close_socket(Socket &socket, bool /*process_socket_ret*/) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - socket_.ssl = nullptr; + socket_.ssl = nullptr; #endif } inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { - std::array buf; + std::array buf; - detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - if (!line_reader.getline()) { return false; } - - const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); - - std::cmatch m; - if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); - res.reason = std::string(m[3]); + if (!line_reader.getline()) { + return false; + } - // Ignore '100 Continue' - while (res.status == 100) { - if (!line_reader.getline()) { return false; } // CRLF - if (!line_reader.getline()) { return false; } // next response line + const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); - if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } res.version = std::string(m[1]); res.status = std::stoi(std::string(m[2])); res.reason = std::string(m[3]); - } - return true; + // Ignore '100 Continue' + while (res.status == 100) { + if (!line_reader.getline()) { + return false; + } // CRLF + if (!line_reader.getline()) { + return false; + } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; } inline bool ClientImpl::send(const Request &req, Response &res) { - std::lock_guard request_mutex_guard(request_mutex_); + std::lock_guard request_mutex_guard(request_mutex_); - { - std::lock_guard guard(socket_mutex_); + { + std::lock_guard guard(socket_mutex_); - auto is_alive = false; - if (socket_.is_open()) { - is_alive = detail::select_write(socket_.sock, 0, 0) > 0; - if (!is_alive) { close_socket(socket_, false); } - } + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::select_write(socket_.sock, 0, 0) > 0; + if (!is_alive) { + close_socket(socket_, false); + } + } - if (!is_alive) { - if (!create_and_connect_socket(socket_)) { return false; } + if (!is_alive) { + if (!create_and_connect_socket(socket_)) { + return false; + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // TODO: refactoring - if (is_ssl()) { - auto &scli = static_cast(*this); - if (!proxy_host_.empty() && proxy_port_ != -1) { - bool success = false; - if (!scli.connect_with_proxy(socket_, res, success)) { - return success; - } - } - - if (!scli.initialize_ssl(socket_)) { return false; } - } + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + bool success = false; + if (!scli.connect_with_proxy(socket_, res, success)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_)) { + return false; + } + } #endif + } } - } - auto close_connection = !keep_alive_; + auto close_connection = !keep_alive_; - auto ret = process_socket(socket_, [&](Stream &strm) { - return handle_request(strm, req, res, close_connection); - }); + auto ret = process_socket(socket_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection); + }); - if (close_connection || !ret) { stop_core(); } + if (close_connection || !ret) { + stop_core(); + } - if (!ret) { - if (error_ == Error::Success) { error_ = Error::Unknown; } - } + if (!ret) { + if (error_ == Error::Success) { + error_ = Error::Unknown; + } + } - return ret; + return ret; } inline bool ClientImpl::handle_request(Stream &strm, const Request &req, Response &res, bool close_connection) { - if (req.path.empty()) { - error_ = Error::Connection; - return false; - } + if (req.path.empty()) { + error_ = Error::Connection; + return false; + } - bool ret; + bool ret; - if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { - auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; - ret = process_request(strm, req2, res, close_connection); - } else { - ret = process_request(strm, req, res, close_connection); - } + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection); + } else { + ret = process_request(strm, req, res, close_connection); + } - if (!ret) { return false; } + if (!ret) { + return false; + } - if (300 < res.status && res.status < 400 && follow_location_) { - ret = redirect(req, res); - } + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if ((res.status == 401 || res.status == 407) && - req.authorization_count_ < 5) { - auto is_proxy = res.status == 407; - const auto &username = - is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; - const auto &password = - is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; - - if (!username.empty() && !password.empty()) { - std::map auth; - if (detail::parse_www_authenticate(res, auth, is_proxy)) { - Request new_req = req; - new_req.authorization_count_ += 1; - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - new_req.headers.erase(key); - new_req.headers.insert(detail::make_digest_authentication_header( - req, auth, new_req.authorization_count_, detail::random_string(10), - username, password, is_proxy)); - - Response new_res; - - ret = send(new_req, new_res); - if (ret) { res = new_res; } - } - } - } + if ((res.status == 401 || res.status == 407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { + res = new_res; + } + } + } + } #endif - return ret; + return ret; } inline bool ClientImpl::redirect(const Request &req, Response &res) { - if (req.redirect_count == 0) { - error_ = Error::ExceedRedirectCount; - return false; - } + if (req.redirect_count == 0) { + error_ = Error::ExceedRedirectCount; + return false; + } - auto location = detail::decode_url(res.get_header_value("location"), true); - if (location.empty()) { return false; } + auto location = detail::decode_url(res.get_header_value("location"), true); + if (location.empty()) { + return false; + } - const static std::regex re( - R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + const static std::regex re( + R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); - std::smatch m; - if (!std::regex_match(location, m, re)) { return false; } + std::smatch m; + if (!std::regex_match(location, m, re)) { + return false; + } - auto scheme = is_ssl() ? "https" : "http"; + auto scheme = is_ssl() ? "https" : "http"; - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - auto port_str = m[3].str(); - auto next_path = m[4].str(); + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); - auto next_port = port_; - if (!port_str.empty()) { - next_port = std::stoi(port_str); - } else if (!next_scheme.empty()) { - next_port = next_scheme == "https" ? 443 : 80; - } + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } - if (next_scheme.empty()) { next_scheme = scheme; } - if (next_host.empty()) { next_host = host_; } - if (next_path.empty()) { next_path = "/"; } + if (next_scheme.empty()) { + next_scheme = scheme; + } + if (next_host.empty()) { + next_host = host_; + } + if (next_path.empty()) { + next_path = "/"; + } - if (next_scheme == scheme && next_host == host_ && next_port == port_) { - return detail::redirect(*this, req, res, next_path); - } else { - if (next_scheme == "https") { + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSLClient cli(next_host.c_str(), next_port); - cli.copy_settings(*this); - auto ret = detail::redirect(cli, req, res, next_path); - if (!ret) { error_ = cli.get_last_error(); } - return ret; + SSLClient cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + auto ret = detail::redirect(cli, req, res, next_path); + if (!ret) { + error_ = cli.get_last_error(); + } + return ret; #else - return false; + return false; #endif - } else { - ClientImpl cli(next_host.c_str(), next_port); - cli.copy_settings(*this); - auto ret = detail::redirect(cli, req, res, next_path); - if (!ret) { error_ = cli.get_last_error(); } - return ret; + } else { + ClientImpl cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + auto ret = detail::redirect(cli, req, res, next_path); + if (!ret) { + error_ = cli.get_last_error(); + } + return ret; + } } - } } inline bool ClientImpl::write_request(Stream &strm, const Request &req, bool close_connection) { - detail::BufferStream bstrm; - - // Request line - const auto &path = detail::encode_url(req.path); + detail::BufferStream bstrm; - bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + // Request line + const auto &path = detail::encode_url(req.path); - // Additonal headers - Headers headers; - if (close_connection) { headers.emplace("Connection", "close"); } + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - if (!req.has_header("Host")) { - if (is_ssl()) { - if (port_ == 443) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } - } else { - if (port_ == 80) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } + // Additonal headers + Headers headers; + if (close_connection) { + headers.emplace("Connection", "close"); } - } - - if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } - if (!req.has_header("User-Agent")) { - headers.emplace("User-Agent", "cpp-httplib/0.7"); - } - - if (req.body.empty()) { - if (req.content_provider) { - auto length = std::to_string(req.content_length); - headers.emplace("Content-Length", length); - } else { - headers.emplace("Content-Length", "0"); - } - } else { - if (!req.has_header("Content-Type")) { - headers.emplace("Content-Type", "text/plain"); + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } } - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - headers.emplace("Content-Length", length); + if (!req.has_header("Accept")) { + headers.emplace("Accept", "*/*"); } - } - if (!basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( - basic_auth_username_, basic_auth_password_, false)); - } + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.7"); + } - if (!proxy_basic_auth_username_.empty() && - !proxy_basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( - proxy_basic_auth_username_, proxy_basic_auth_password_, true)); - } + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } - if (!bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( - bearer_token_auth_token_, false)); - } + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } - if (!proxy_bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( - proxy_bearer_token_auth_token_, true)); - } + if (!basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } - detail::write_headers(bstrm, req, headers); + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } - // Flush buffer - auto &data = bstrm.get_buffer(); - if (!detail::write_data(strm, data.data(), data.size())) { - error_ = Error::Write; - return false; - } + if (!bearer_token_auth_token_.empty()) { + headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } - // Body - if (req.body.empty()) { - if (req.content_provider) { - size_t offset = 0; - size_t end_offset = req.content_length; + if (!proxy_bearer_token_auth_token_.empty()) { + headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } - bool ok = true; + detail::write_headers(bstrm, req, headers); - DataSink data_sink; - data_sink.write = [&](const char *d, size_t l) { - if (ok) { - if (detail::write_data(strm, d, l)) { - offset += l; - } else { - ok = false; - } - } - }; - data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error_ = Error::Write; + return false; + } - while (offset < end_offset) { - if (!req.content_provider(offset, end_offset - offset, data_sink)) { - error_ = Error::Canceled; - return false; - } - if (!ok) { - error_ = Error::Write; - return false; + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + bool ok = true; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + if (detail::write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (offset < end_offset) { + if (!req.content_provider(offset, end_offset - offset, data_sink)) { + error_ = Error::Canceled; + return false; + } + if (!ok) { + error_ = Error::Write; + return false; + } + } } - } + } else { + return detail::write_data(strm, req.body.data(), req.body.size()); } - } else { - return detail::write_data(strm, req.body.data(), req.body.size()); - } - return true; + return true; } inline std::shared_ptr ClientImpl::send_with_content_provider( @@ -4956,541 +5365,572 @@ inline std::shared_ptr ClientImpl::send_with_content_provider( const std::string &body, size_t content_length, ContentProvider content_provider, const char *content_type) { - Request req; - req.method = method; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = method; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - if (content_type) { req.headers.emplace("Content-Type", content_type); } + if (content_type) { + req.headers.emplace("Content-Type", content_type); + } #ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (compress_) { - detail::gzip_compressor compressor; - - if (content_provider) { - auto ok = true; - size_t offset = 0; - - DataSink data_sink; - data_sink.write = [&](const char *data, size_t data_len) { - if (ok) { - auto last = offset + data_len == content_length; - - auto ret = compressor.compress( - data, data_len, last, [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - }); - - if (ret) { - offset += data_len; - } else { - ok = false; - } - } - }; - data_sink.is_writable = [&](void) { return ok && true; }; - - while (ok && offset < content_length) { - if (!content_provider(offset, content_length - offset, data_sink)) { - error_ = Error::Canceled; - return nullptr; + if (compress_) { + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && true; }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error_ = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body.data(), body.size(), true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + return nullptr; + } } - } - } else { - if (!compressor.compress(body.data(), body.size(), true, - [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - })) { - return nullptr; - } - } - req.headers.emplace("Content-Encoding", "gzip"); - } else + req.headers.emplace("Content-Encoding", "gzip"); + } else #endif - { - if (content_provider) { - req.content_length = content_length; - req.content_provider = content_provider; - } else { - req.body = body; + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } } - } - auto res = std::make_shared(); + auto res = std::make_shared(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? res : nullptr; } inline bool ClientImpl::process_request(Stream &strm, const Request &req, Response &res, bool close_connection) { - // Send request - if (!write_request(strm, req, close_connection)) { return false; } + // Send request + if (!write_request(strm, req, close_connection)) { + return false; + } - // Receive response and headers - if (!read_response_line(strm, res) || - !detail::read_headers(strm, res.headers)) { - error_ = Error::Read; - return false; - } + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + error_ = Error::Read; + return false; + } - if (req.response_handler) { - if (!req.response_handler(res)) { - error_ = Error::Canceled; - return false; + if (req.response_handler) { + if (!req.response_handler(res)) { + error_ = Error::Canceled; + return false; + } } - } - // Body - if (req.method != "HEAD" && req.method != "CONNECT") { - auto out = - req.content_receiver - ? static_cast([&](const char *buf, size_t n) { + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + auto out = + req.content_receiver ? static_cast([&](const char *buf, size_t n) { auto ret = req.content_receiver(buf, n); - if (!ret) { error_ = Error::Canceled; } + if (!ret) { + error_ = Error::Canceled; + } return ret; - }) - : static_cast([&](const char *buf, size_t n) { - if (res.body.size() + n > res.body.max_size()) { return false; } - res.body.append(buf, n); + }) : + static_cast([&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { + return false; + } + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress) { return true; - }); - - auto progress = [&](uint64_t current, uint64_t total) { - if (!req.progress) { return true; } - auto ret = req.progress(current, total); - if (!ret) { error_ = Error::Canceled; } - return ret; - }; + } + auto ret = req.progress(current, total); + if (!ret) { + error_ = Error::Canceled; + } + return ret; + }; - int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, progress, out, decompress_)) { - if (error_ != Error::Canceled) { error_ = Error::Read; } - return false; + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, progress, out, decompress_)) { + if (error_ != Error::Canceled) { + error_ = Error::Read; + } + return false; + } } - } - if (res.get_header_value("Connection") == "close" || - (res.version == "HTTP/1.0" && res.reason != "Connection established")) { - stop_core(); - } + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + stop_core(); + } - // Log - if (logger_) { logger_(req, res); } + // Log + if (logger_) { + logger_(req, res); + } - return true; + return true; } inline bool ClientImpl::process_socket(Socket &socket, std::function callback) { - return detail::process_client_socket(socket.sock, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, callback); + return detail::process_client_socket(socket.sock, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, callback); } -inline bool ClientImpl::is_ssl() const { return false; } +inline bool ClientImpl::is_ssl() const { + return false; +} inline Result ClientImpl::Get(const char *path) { - return Get(path, Headers(), Progress()); + return Get(path, Headers(), Progress()); } inline Result ClientImpl::Get(const char *path, Progress progress) { - return Get(path, Headers(), std::move(progress)); + return Get(path, Headers(), std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers) { - return Get(path, headers, Progress()); + return Get(path, headers, Progress()); } inline Result ClientImpl::Get(const char *path, const Headers &headers, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.progress = std::move(progress); + Request req; + req.method = "GET"; + req.path = path; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.progress = std::move(progress); - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Get(const char *path, ContentReceiver content_receiver) { - return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), nullptr, std::move(content_receiver), - std::move(progress)); + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) { - return Get(path, headers, nullptr, std::move(content_receiver), nullptr); + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, Progress progress) { - return Get(path, headers, nullptr, std::move(content_receiver), - std::move(progress)); + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - nullptr); + return Get(path, Headers(), std::move(response_handler), content_receiver, + nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, headers, std::move(response_handler), content_receiver, - nullptr); + return Get(path, headers, std::move(response_handler), content_receiver, + nullptr); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - progress); + return Get(path, Headers(), std::move(response_handler), content_receiver, + progress); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.response_handler = std::move(response_handler); - req.content_receiver = std::move(content_receiver); - req.progress = std::move(progress); + Request req; + req.method = "GET"; + req.path = path; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Head(const char *path) { - return Head(path, Headers()); + return Head(path, Headers()); } inline Result ClientImpl::Head(const char *path, const Headers &headers) { - Request req; - req.method = "HEAD"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = "HEAD"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Post(const char *path) { - return Post(path, std::string(), nullptr); + return Post(path, std::string(), nullptr); } inline Result ClientImpl::Post(const char *path, const std::string &body, const char *content_type) { - return Post(path, Headers(), body, content_type); + return Post(path, Headers(), body, content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Params ¶ms) { - return Post(path, Headers(), params); + return Post(path, Headers(), params); } inline Result ClientImpl::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Post(path, Headers(), content_length, content_provider, content_type); + return Post(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Headers &headers, const Params ¶ms) { - auto query = detail::params_to_query_str(params); - return Post(path, headers, query, "application/x-www-form-urlencoded"); + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); } inline Result ClientImpl::Post(const char *path, const MultipartFormDataItems &items) { - return Post(path, Headers(), items); + return Post(path, Headers(), items); } inline Result ClientImpl::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - auto boundary = detail::make_multipart_data_boundary(); + auto boundary = detail::make_multipart_data_boundary(); - std::string body; + std::string body; - for (const auto &item : items) { - body += "--" + boundary + "\r\n"; - body += "Content-Disposition: form-data; name=\"" + item.name + "\""; - if (!item.filename.empty()) { - body += "; filename=\"" + item.filename + "\""; - } - body += "\r\n"; - if (!item.content_type.empty()) { - body += "Content-Type: " + item.content_type + "\r\n"; + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; } - body += "\r\n"; - body += item.content + "\r\n"; - } - body += "--" + boundary + "--\r\n"; + body += "--" + boundary + "--\r\n"; - std::string content_type = "multipart/form-data; boundary=" + boundary; - return Post(path, headers, body, content_type.c_str()); + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); } inline Result ClientImpl::Put(const char *path) { - return Put(path, std::string(), nullptr); + return Put(path, std::string(), nullptr); } inline Result ClientImpl::Put(const char *path, const std::string &body, const char *content_type) { - return Put(path, Headers(), body, content_type); + return Put(path, Headers(), body, content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Put(path, Headers(), content_length, content_provider, content_type); + return Put(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Put(const char *path, const Params ¶ms) { - return Put(path, Headers(), params); + return Put(path, Headers(), params); } inline Result ClientImpl::Put(const char *path, const Headers &headers, const Params ¶ms) { - auto query = detail::params_to_query_str(params); - return Put(path, headers, query, "application/x-www-form-urlencoded"); + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); } inline Result ClientImpl::Patch(const char *path, const std::string &body, const char *content_type) { - return Patch(path, Headers(), body, content_type); + return Patch(path, Headers(), body, content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, body, 0, - nullptr, content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PATCH", path, headers, body, 0, + nullptr, content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Patch(path, Headers(), content_length, content_provider, content_type); + return Patch(path, Headers(), content_length, content_provider, content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); + return Result{ret, get_last_error()}; } inline Result ClientImpl::Delete(const char *path) { - return Delete(path, Headers(), std::string(), nullptr); + return Delete(path, Headers(), std::string(), nullptr); } inline Result ClientImpl::Delete(const char *path, const std::string &body, const char *content_type) { - return Delete(path, Headers(), body, content_type); + return Delete(path, Headers(), body, content_type); } inline Result ClientImpl::Delete(const char *path, const Headers &headers) { - return Delete(path, headers, std::string(), nullptr); + return Delete(path, headers, std::string(), nullptr); } inline Result ClientImpl::Delete(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - Request req; - req.method = "DELETE"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; - - if (content_type) { req.headers.emplace("Content-Type", content_type); } - req.body = body; + Request req; + req.method = "DELETE"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; + + if (content_type) { + req.headers.emplace("Content-Type", content_type); + } + req.body = body; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline Result ClientImpl::Options(const char *path) { - return Options(path, Headers()); + return Options(path, Headers()); } inline Result ClientImpl::Options(const char *path, const Headers &headers) { - Request req; - req.method = "OPTIONS"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + Request req; + req.method = "OPTIONS"; + req.headers = default_headers_; + req.headers.insert(headers.begin(), headers.end()); + req.path = path; - auto res = std::make_shared(); - auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + auto res = std::make_shared(); + auto ret = send(req, *res); + return Result{ret ? res : nullptr, get_last_error()}; } inline size_t ClientImpl::is_socket_open() const { - std::lock_guard guard(socket_mutex_); - return socket_.is_open(); + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); } inline void ClientImpl::stop() { - stop_core(); - error_ = Error::Canceled; + stop_core(); + error_ = Error::Canceled; } inline void ClientImpl::stop_core() { - std::lock_guard guard(socket_mutex_); - if (socket_.is_open()) { - detail::shutdown_socket(socket_.sock); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - close_socket(socket_, true); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } + std::lock_guard guard(socket_mutex_); + if (socket_.is_open()) { + detail::shutdown_socket(socket_.sock); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + close_socket(socket_, true); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } } inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { - connection_timeout_sec_ = sec; - connection_timeout_usec_ = usec; + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; } inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { - read_timeout_sec_ = sec; - read_timeout_usec_ = usec; + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { - write_timeout_sec_ = sec; - write_timeout_usec_ = usec; + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; } inline void ClientImpl::set_basic_auth(const char *username, const char *password) { - basic_auth_username_ = username; - basic_auth_password_ = password; + basic_auth_username_ = username; + basic_auth_password_ = password; } inline void ClientImpl::set_bearer_token_auth(const char *token) { - bearer_token_auth_token_ = token; + bearer_token_auth_token_ = token; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void ClientImpl::set_digest_auth(const char *username, const char *password) { - digest_auth_username_ = username; - digest_auth_password_ = password; + digest_auth_username_ = username; + digest_auth_password_ = password; } #endif -inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } +inline void ClientImpl::set_keep_alive(bool on) { + keep_alive_ = on; +} -inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } +inline void ClientImpl::set_follow_location(bool on) { + follow_location_ = on; +} inline void ClientImpl::set_default_headers(Headers headers) { - default_headers_ = std::move(headers); + default_headers_ = std::move(headers); } -inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void ClientImpl::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} inline void ClientImpl::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = socket_options; } -inline void ClientImpl::set_compress(bool on) { compress_ = on; } +inline void ClientImpl::set_compress(bool on) { + compress_ = on; +} -inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } +inline void ClientImpl::set_decompress(bool on) { + decompress_ = on; +} -inline void ClientImpl::set_interface(const char *intf) { interface_ = intf; } +inline void ClientImpl::set_interface(const char *intf) { + interface_ = intf; +} inline void ClientImpl::set_proxy(const char *host, int port) { - proxy_host_ = host; - proxy_port_ = port; + proxy_host_ = host; + proxy_port_ = port; } inline void ClientImpl::set_proxy_basic_auth(const char *username, const char *password) { - proxy_basic_auth_username_ = username; - proxy_basic_auth_password_ = password; + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; } inline void ClientImpl::set_proxy_bearer_token_auth(const char *token) { - proxy_bearer_token_auth_token_ = token; + proxy_bearer_token_auth_token_ = token; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void ClientImpl::set_proxy_digest_auth(const char *username, const char *password) { - proxy_digest_auth_username_ = username; - proxy_digest_auth_password_ = password; + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; } #endif inline void ClientImpl::set_logger(Logger logger) { - logger_ = std::move(logger); + logger_ = std::move(logger); } /* @@ -5499,66 +5939,66 @@ inline void ClientImpl::set_logger(Logger logger) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT namespace detail { -template +template inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup) { - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } - if (ssl) { - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); + if (ssl) { + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); - if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - return nullptr; + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + return nullptr; + } } - } - return ssl; + return ssl; } inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, bool process_socket_ret) { - if (process_socket_ret) { - SSL_shutdown(ssl); // shutdown only if not already closed by remote - } + if (process_socket_ret) { + SSL_shutdown(ssl); // shutdown only if not already closed by remote + } - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); } -template +template inline bool process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); + return process_server_socket_core( + sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); } -template +template inline bool process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm); + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); } #if OPENSSL_VERSION_NUMBER < 0x10100000L @@ -5566,49 +6006,51 @@ static std::shared_ptr> openSSL_locks_; class SSLThreadLocks { public: - SSLThreadLocks() { - openSSL_locks_ = - std::make_shared>(CRYPTO_num_locks()); - CRYPTO_set_locking_callback(locking_callback); - } + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } - ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + ~SSLThreadLocks() { + CRYPTO_set_locking_callback(nullptr); + } private: - static void locking_callback(int mode, int type, const char * /*file*/, - int /*line*/) { - auto &lk = (*openSSL_locks_)[static_cast(type)]; - if (mode & CRYPTO_LOCK) { - lk.lock(); - } else { - lk.unlock(); + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &lk = (*openSSL_locks_)[static_cast(type)]; + if (mode & CRYPTO_LOCK) { + lk.lock(); + } else { + lk.unlock(); + } } - } }; #endif class SSLInit { public: - SSLInit() { + SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - SSL_load_error_strings(); - SSL_library_init(); + SSL_load_error_strings(); + SSL_library_init(); #else - OPENSSL_init_ssl( - OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); #endif - } + } - ~SSLInit() { + ~SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - ERR_free_strings(); + ERR_free_strings(); #endif - } + } private: #if OPENSSL_VERSION_NUMBER < 0x10100000L - SSLThreadLocks thread_init_; + SSLThreadLocks thread_init_; #endif }; @@ -5622,839 +6064,904 @@ inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec) { - { - timeval tv; - tv.tv_sec = static_cast(read_timeout_sec); - tv.tv_usec = static_cast(read_timeout_usec); + { + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), - sizeof(tv)); - } - { - timeval tv; - tv.tv_sec = static_cast(write_timeout_sec); - tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + } + { + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); - setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), - sizeof(tv)); - } + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + } } -inline SSLSocketStream::~SSLSocketStream() {} +inline SSLSocketStream::~SSLSocketStream() { +} inline bool SSLSocketStream::is_readable() const { - return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SSLSocketStream::is_writable() const { - return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > - 0; + return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > + 0; } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0 || is_readable()) { - return SSL_read(ssl_, ptr, static_cast(size)); - } - return -1; + if (SSL_pending(ssl_) > 0 || is_readable()) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } - return -1; + if (is_writable()) { + return SSL_write(ssl_, ptr, static_cast(size)); + } + return -1; } inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { - detail::get_remote_ip_and_port(sock_, ip, port); + detail::get_remote_ip_and_port(sock_, ip, port); } static SSLInit sslinit_; -} // namespace detail +} // namespace detail // SSL HTTP server implementation inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, const char *client_ca_cert_dir_path) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); - - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - // if (client_ca_cert_file_path) { - // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); - // SSL_CTX_set_client_CA_list(ctx_, list); - // } - - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); - - SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } - } } inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - if (SSL_CTX_use_certificate(ctx_, cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_store) { - - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - - SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } - } } inline SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { + SSL_CTX_free(ctx_); + } } -inline bool SSLServer::is_valid() const { return ctx_; } +inline bool SSLServer::is_valid() const { + return ctx_; +} inline bool SSLServer::process_and_close_socket(socket_t sock) { - auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept, - [](SSL * /*ssl*/) { return true; }); - - if (ssl) { - auto ret = detail::process_server_socket_ssl( - ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, - [this, ssl](Stream &strm, bool close_connection, - bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, - [&](Request &req) { req.ssl = ssl; }); - }); - - detail::ssl_delete(ctx_mutex_, ssl, ret); - return ret; - } + auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept, + [](SSL * /*ssl*/) { return true; }); + + if (ssl) { + auto ret = detail::process_server_socket_ssl( + ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [this, ssl](Stream &strm, bool close_connection, + bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + detail::ssl_delete(ctx_mutex_, ssl, ret); + return ret; + } - detail::close_socket(sock); - return false; + detail::close_socket(sock); + return false; } // SSL HTTP client implementation inline SSLClient::SSLClient(const std::string &host) - : SSLClient(host, 443, std::string(), std::string()) {} + : SSLClient(host, 443, std::string(), std::string()) { +} inline SSLClient::SSLClient(const std::string &host, int port) - : SSLClient(host, port, std::string(), std::string()) {} + : SSLClient(host, port, std::string(), std::string()) { +} inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (!client_cert_path.empty() && !client_key_path.empty()) { - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } - } } inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key) : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (client_cert != nullptr && client_key != nullptr) { - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert != nullptr && client_key != nullptr) { + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } - } } inline SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { + SSL_CTX_free(ctx_); + } } -inline bool SSLClient::is_valid() const { return ctx_; } +inline bool SSLClient::is_valid() const { + return ctx_; +} inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } - if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } + if (ca_cert_file_path) { + ca_cert_file_path_ = ca_cert_file_path; + } + if (ca_cert_dir_path) { + ca_cert_dir_path_ = ca_cert_dir_path; + } } inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { ca_cert_store_ = ca_cert_store; } + if (ca_cert_store) { + ca_cert_store_ = ca_cert_store; + } } inline void SSLClient::enable_server_certificate_verification(bool enabled) { - server_certificate_verification_ = enabled; + server_certificate_verification_ = enabled; } inline long SSLClient::get_openssl_verify_result() const { - return verify_result_; + return verify_result_; } -inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } +inline SSL_CTX *SSLClient::ssl_context() const { + return ctx_; +} inline bool SSLClient::create_and_connect_socket(Socket &socket) { - return is_valid() && ClientImpl::create_and_connect_socket(socket); + return is_valid() && ClientImpl::create_and_connect_socket(socket); } inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success) { - success = true; - Response res2; - - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { - Request req2; - req2.method = "CONNECT"; - req2.path = host_and_port_; - return process_request(strm, req2, res2, false); - })) { - close_socket(socket, true); - success = false; - return false; - } - - if (res2.status == 407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (detail::parse_www_authenticate(res2, auth, true)) { - Response res3; - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { - Request req3; - req3.method = "CONNECT"; - req3.path = host_and_port_; - req3.headers.insert(detail::make_digest_authentication_header( - req3, auth, 1, detail::random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - return process_request(strm, req3, res3, false); - })) { - close_socket(socket, true); - success = false; - return false; - } - } - } else { - res = res2; - return false; + success = true; + Response res2; + + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false); + })) { + close_socket(socket, true); + success = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false); + })) { + close_socket(socket, true); + success = false; + return false; + } + } + } else { + res = res2; + return false; + } } - } - return true; + return true; } inline bool SSLClient::load_certs() { - bool ret = true; - - std::call_once(initialize_cert_, [&]() { - std::lock_guard guard(ctx_mutex_); - if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { - ret = false; - } - } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { - ret = false; - } - } else if (ca_cert_store_ != nullptr) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { - SSL_CTX_set_cert_store(ctx_, ca_cert_store_); - } - } else { + bool ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + ret = false; + } + } else if (ca_cert_store_ != nullptr) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { + SSL_CTX_set_cert_store(ctx_, ca_cert_store_); + } + } else { #ifdef _WIN32 - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); #else SSL_CTX_set_default_verify_paths(ctx_); #endif - } - }); + } + }); - return ret; + return ret; } inline bool SSLClient::initialize_ssl(Socket &socket) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl) { - if (server_certificate_verification_) { - if (!load_certs()) { - error_ = Error::SSLLoadingCerts; - return false; - } - SSL_set_verify(ssl, SSL_VERIFY_NONE, nullptr); - } + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (server_certificate_verification_) { + if (!load_certs()) { + error_ = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl, SSL_VERIFY_NONE, nullptr); + } - if (SSL_connect(ssl) != 1) { - error_ = Error::SSLConnection; - return false; - } + if (SSL_connect(ssl) != 1) { + error_ = Error::SSLConnection; + return false; + } - if (server_certificate_verification_) { - verify_result_ = SSL_get_verify_result(ssl); + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); - if (verify_result_ != X509_V_OK) { - error_ = Error::SSLServerVerification; - return false; - } + if (verify_result_ != X509_V_OK) { + error_ = Error::SSLServerVerification; + return false; + } - auto server_cert = SSL_get_peer_certificate(ssl); + auto server_cert = SSL_get_peer_certificate(ssl); - if (server_cert == nullptr) { - error_ = Error::SSLServerVerification; - return false; - } + if (server_cert == nullptr) { + error_ = Error::SSLServerVerification; + return false; + } - if (!verify_host(server_cert)) { - X509_free(server_cert); - error_ = Error::SSLServerVerification; - return false; - } - X509_free(server_cert); - } + if (!verify_host(server_cert)) { + X509_free(server_cert); + error_ = Error::SSLServerVerification; + return false; + } + X509_free(server_cert); + } - return true; - }, - [&](SSL *ssl) { - SSL_set_tlsext_host_name(ssl, host_.c_str()); - return true; - }); + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }); - if (ssl) { - socket.ssl = ssl; - return true; - } + if (ssl) { + socket.ssl = ssl; + return true; + } - close_socket(socket, false); - return false; + close_socket(socket, false); + return false; } inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); - socket_.ssl = nullptr; - } + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); + socket_.ssl = nullptr; + } } inline bool SSLClient::process_socket(Socket &socket, std::function callback) { - assert(socket.ssl); - return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, callback); + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, callback); } -inline bool SSLClient::is_ssl() const { return true; } +inline bool SSLClient::is_ssl() const { + return true; +} inline bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" + /* Quote from RFC2818 section 3.1 "Server Identity" - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); } inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; + auto ret = false; - auto type = GEN_DNS; + auto type = GEN_DNS; - struct in6_addr addr6; - struct in_addr addr; - size_t addr_len = 0; + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; #ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); - } + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } #endif - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_mached = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (val->type == type) { - auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); - auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); - - if (strlen(name) == name_len) { - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || - !memcmp(&addr, name, addr_len)) { - ip_mached = true; + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } } - break; - } } - } - } - if (dsn_matched || ip_mached) { ret = true; } - } + if (dsn_matched || ip_mached) { + ret = true; + } + } - GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); - return ret; + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + return ret; } inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); + const auto subject_name = X509_get_subject_name(server_cert); - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); - if (name_len != -1) { - return check_host_name(name, static_cast(name_len)); + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } } - } - return false; + return false; } inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const { - if (host_.size() == pattern_len && host_ == pattern) { return true; } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(std::string(b, e)); - }); + if (host_.size() == pattern_len && host_ == pattern) { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); - if (host_components_.size() != pattern_components.size()) { return false; } + if (host_components_.size() != pattern_components.size()) { + return false; + } - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (p != h && p != "*") { - auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && - !p.compare(0, p.size() - 1, h)); - if (!partial_match) { return false; } + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { + return false; + } + } + ++itr; } - ++itr; - } - return true; + return true; } #endif // Universal client implementation inline Client::Client(const char *scheme_host_port) - : Client(scheme_host_port, std::string(), std::string()) {} + : Client(scheme_host_port, std::string(), std::string()) { +} inline Client::Client(const char *scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { - const static std::regex re(R"(^(?:([a-z]+)://)?([^:/?#]+)(?::(\d+))?)"); + const static std::regex re(R"(^(?:([a-z]+)://)?([^:/?#]+)(?::(\d+))?)"); - std::cmatch m; - if (std::regex_match(scheme_host_port, m, re)) { - auto scheme = m[1].str(); + std::cmatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (!scheme.empty() && (scheme != "http" && scheme != "https")) { + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { #else - if (!scheme.empty() && scheme != "http") { + if (!scheme.empty() && scheme != "http") { #endif - return; - } + return; + } - auto is_ssl = scheme == "https"; + auto is_ssl = scheme == "https"; - auto host = m[2].str(); + auto host = m[2].str(); - auto port_str = m[3].str(); - auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + auto port_str = m[3].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); - if (is_ssl) { + if (is_ssl) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - cli_ = std::make_shared(host.c_str(), port, client_cert_path, - client_key_path); - is_ssl_ = is_ssl; + cli_ = std::make_shared(host.c_str(), port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; #endif + } else { + cli_ = std::make_shared(host.c_str(), port, client_cert_path, + client_key_path); + } } else { - cli_ = std::make_shared(host.c_str(), port, client_cert_path, - client_key_path); + cli_ = std::make_shared(scheme_host_port, 80, client_cert_path, + client_key_path); } - } else { - cli_ = std::make_shared(scheme_host_port, 80, client_cert_path, - client_key_path); - } } inline Client::Client(const std::string &host, int port) - : cli_(std::make_shared(host, port)) {} + : cli_(std::make_shared(host, port)) { +} inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : cli_(std::make_shared(host, port, client_cert_path, - client_key_path)) {} + client_key_path)) { +} -inline Client::~Client() {} +inline Client::~Client() { +} inline bool Client::is_valid() const { - return cli_ != nullptr && cli_->is_valid(); + return cli_ != nullptr && cli_->is_valid(); } -inline Result Client::Get(const char *path) { return cli_->Get(path); } +inline Result Client::Get(const char *path) { + return cli_->Get(path); +} inline Result Client::Get(const char *path, const Headers &headers) { - return cli_->Get(path, headers); + return cli_->Get(path, headers); } inline Result Client::Get(const char *path, Progress progress) { - return cli_->Get(path, progress); + return cli_->Get(path, progress); } inline Result Client::Get(const char *path, const Headers &headers, Progress progress) { - return cli_->Get(path, headers, progress); + return cli_->Get(path, headers, progress); } inline Result Client::Get(const char *path, ContentReceiver content_receiver) { - return cli_->Get(path, std::move(content_receiver)); + return cli_->Get(path, std::move(content_receiver)); } inline Result Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) { - return cli_->Get(path, headers, std::move(content_receiver)); + return cli_->Get(path, headers, std::move(content_receiver)); } inline Result Client::Get(const char *path, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, std::move(content_receiver), std::move(progress)); + return cli_->Get(path, std::move(content_receiver), std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, std::move(content_receiver), - std::move(progress)); + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); } inline Result Client::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return cli_->Get(path, std::move(response_handler), - std::move(content_receiver)); + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver)); } inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return cli_->Get(path, headers, std::move(response_handler), - std::move(content_receiver)); + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver)); } inline Result Client::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, std::move(response_handler), - std::move(content_receiver), std::move(progress)); + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, response_handler, content_receiver, progress); + return cli_->Get(path, headers, response_handler, content_receiver, progress); } -inline Result Client::Head(const char *path) { return cli_->Head(path); } +inline Result Client::Head(const char *path) { + return cli_->Head(path); +} inline Result Client::Head(const char *path, const Headers &headers) { - return cli_->Head(path, headers); + return cli_->Head(path, headers); } -inline Result Client::Post(const char *path) { return cli_->Post(path); } +inline Result Client::Post(const char *path) { + return cli_->Post(path); +} inline Result Client::Post(const char *path, const std::string &body, const char *content_type) { - return cli_->Post(path, body, content_type); + return cli_->Post(path, body, content_type); } inline Result Client::Post(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Post(path, headers, body, content_type); + return cli_->Post(path, headers, body, content_type); } inline Result Client::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, content_length, content_provider, content_type); + return cli_->Post(path, content_length, content_provider, content_type); } inline Result Client::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, headers, content_length, content_provider, - content_type); + return cli_->Post(path, headers, content_length, content_provider, + content_type); } inline Result Client::Post(const char *path, const Params ¶ms) { - return cli_->Post(path, params); + return cli_->Post(path, params); } inline Result Client::Post(const char *path, const Headers &headers, const Params ¶ms) { - return cli_->Post(path, headers, params); + return cli_->Post(path, headers, params); } inline Result Client::Post(const char *path, const MultipartFormDataItems &items) { - return cli_->Post(path, items); + return cli_->Post(path, items); } inline Result Client::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - return cli_->Post(path, headers, items); + return cli_->Post(path, headers, items); +} +inline Result Client::Put(const char *path) { + return cli_->Put(path); } -inline Result Client::Put(const char *path) { return cli_->Put(path); } inline Result Client::Put(const char *path, const std::string &body, const char *content_type) { - return cli_->Put(path, body, content_type); + return cli_->Put(path, body, content_type); } inline Result Client::Put(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Put(path, headers, body, content_type); + return cli_->Put(path, headers, body, content_type); } inline Result Client::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, content_length, content_provider, content_type); + return cli_->Put(path, content_length, content_provider, content_type); } inline Result Client::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, headers, content_length, content_provider, - content_type); + return cli_->Put(path, headers, content_length, content_provider, + content_type); } inline Result Client::Put(const char *path, const Params ¶ms) { - return cli_->Put(path, params); + return cli_->Put(path, params); } inline Result Client::Put(const char *path, const Headers &headers, const Params ¶ms) { - return cli_->Put(path, headers, params); + return cli_->Put(path, headers, params); } inline Result Client::Patch(const char *path, const std::string &body, const char *content_type) { - return cli_->Patch(path, body, content_type); + return cli_->Patch(path, body, content_type); } inline Result Client::Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Patch(path, headers, body, content_type); + return cli_->Patch(path, headers, body, content_type); } inline Result Client::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, content_length, content_provider, content_type); + return cli_->Patch(path, content_length, content_provider, content_type); } inline Result Client::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, headers, content_length, content_provider, - content_type); + return cli_->Patch(path, headers, content_length, content_provider, + content_type); +} +inline Result Client::Delete(const char *path) { + return cli_->Delete(path); } -inline Result Client::Delete(const char *path) { return cli_->Delete(path); } inline Result Client::Delete(const char *path, const std::string &body, const char *content_type) { - return cli_->Delete(path, body, content_type); + return cli_->Delete(path, body, content_type); } inline Result Client::Delete(const char *path, const Headers &headers) { - return cli_->Delete(path, headers); + return cli_->Delete(path, headers); } inline Result Client::Delete(const char *path, const Headers &headers, const std::string &body, const char *content_type) { - return cli_->Delete(path, headers, body, content_type); + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Options(const char *path) { + return cli_->Options(path); } -inline Result Client::Options(const char *path) { return cli_->Options(path); } inline Result Client::Options(const char *path, const Headers &headers) { - return cli_->Options(path, headers); + return cli_->Options(path, headers); } inline bool Client::send(const Request &req, Response &res) { - return cli_->send(req, res); + return cli_->send(req, res); } -inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } +inline size_t Client::is_socket_open() const { + return cli_->is_socket_open(); +} -inline void Client::stop() { cli_->stop(); } +inline void Client::stop() { + cli_->stop(); +} inline void Client::set_default_headers(Headers headers) { - cli_->set_default_headers(std::move(headers)); + cli_->set_default_headers(std::move(headers)); } -inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } +inline void Client::set_tcp_nodelay(bool on) { + cli_->set_tcp_nodelay(on); +} inline void Client::set_socket_options(SocketOptions socket_options) { - cli_->set_socket_options(socket_options); + cli_->set_socket_options(socket_options); } inline void Client::set_connection_timeout(time_t sec, time_t usec) { - cli_->set_connection_timeout(sec, usec); + cli_->set_connection_timeout(sec, usec); } inline void Client::set_read_timeout(time_t sec, time_t usec) { - cli_->set_read_timeout(sec, usec); + cli_->set_read_timeout(sec, usec); } inline void Client::set_write_timeout(time_t sec, time_t usec) { - cli_->set_write_timeout(sec, usec); + cli_->set_write_timeout(sec, usec); } inline void Client::set_basic_auth(const char *username, const char *password) { - cli_->set_basic_auth(username, password); + cli_->set_basic_auth(username, password); } inline void Client::set_bearer_token_auth(const char *token) { - cli_->set_bearer_token_auth(token); + cli_->set_bearer_token_auth(token); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void Client::set_digest_auth(const char *username, const char *password) { - cli_->set_digest_auth(username, password); + cli_->set_digest_auth(username, password); } #endif -inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_keep_alive(bool on) { + cli_->set_keep_alive(on); +} inline void Client::set_follow_location(bool on) { - cli_->set_follow_location(on); + cli_->set_follow_location(on); } -inline void Client::set_compress(bool on) { cli_->set_compress(on); } +inline void Client::set_compress(bool on) { + cli_->set_compress(on); +} -inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } +inline void Client::set_decompress(bool on) { + cli_->set_decompress(on); +} inline void Client::set_interface(const char *intf) { - cli_->set_interface(intf); + cli_->set_interface(intf); } inline void Client::set_proxy(const char *host, int port) { - cli_->set_proxy(host, port); + cli_->set_proxy(host, port); } inline void Client::set_proxy_basic_auth(const char *username, const char *password) { - cli_->set_proxy_basic_auth(username, password); + cli_->set_proxy_basic_auth(username, password); } inline void Client::set_proxy_bearer_token_auth(const char *token) { - cli_->set_proxy_bearer_token_auth(token); + cli_->set_proxy_bearer_token_auth(token); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline void Client::set_proxy_digest_auth(const char *username, const char *password) { - cli_->set_proxy_digest_auth(username, password); + cli_->set_proxy_digest_auth(username, password); } #endif -inline void Client::set_logger(Logger logger) { cli_->set_logger(logger); } +inline void Client::set_logger(Logger logger) { + cli_->set_logger(logger); +} #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline Client &Client::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - if (is_ssl_) { - static_cast(*cli_).set_ca_cert_path(ca_cert_file_path, - ca_cert_dir_path); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_path(ca_cert_file_path, + ca_cert_dir_path); + } + return *this; } inline Client &Client::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (is_ssl_) { - static_cast(*cli_).set_ca_cert_store(ca_cert_store); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } + return *this; } inline Client &Client::enable_server_certificate_verification(bool enabled) { - if (is_ssl_) { - static_cast(*cli_).enable_server_certificate_verification( - enabled); - } - return *this; + if (is_ssl_) { + static_cast(*cli_).enable_server_certificate_verification( + enabled); + } + return *this; } inline long Client::get_openssl_verify_result() const { - if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); - } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } inline SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } - return nullptr; + if (is_ssl_) { + return static_cast(*cli_).ssl_context(); + } + return nullptr; } #endif // ---------------------------------------------------------------------------- -} // namespace httplib +} // namespace httplib -#endif // CPPHTTPLIB_HTTPLIB_H +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/src/bb/image-io/rt.h b/src/bb/image-io/rt.h index 98f2fbb5..703c74d8 100644 --- a/src/bb/image-io/rt.h +++ b/src/bb/image-io/rt.h @@ -16,15 +16,15 @@ namespace image_io { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; -} // image_io -} // bb -} // ion +} // namespace image_io +} // namespace bb +} // namespace ion #define ION_REGISTER_EXTERN(NAME) static auto ion_register_extern_##NAME = ion::bb::image_io::RegisterExtern(#NAME, NAME); #include "rt_u3v.h" @@ -36,7 +36,6 @@ class RegisterExtern { #include "rt_v4l2.h" #endif - #undef ION_REGISTER_EXTERN #endif diff --git a/src/bb/image-io/rt_common.h b/src/bb/image-io/rt_common.h index c0475f89..8a730b43 100644 --- a/src/bb/image-io/rt_common.h +++ b/src/bb/image-io/rt_common.h @@ -17,7 +17,6 @@ #include "ion/export.h" - #include "log.h" #include "httplib.h" @@ -38,7 +37,7 @@ namespace bb { namespace image_io { template -std::string format(const char *fmt, const Rest &... rest) { +std::string format(const char *fmt, const Rest &...rest) { int length = snprintf(NULL, 0, fmt, rest...) + 1; // Explicit place for null termination std::vector buf(length, 0); snprintf(&buf[0], length, fmt, rest...); @@ -158,8 +157,9 @@ std::tuple parse_url(const std::string &url) { template class ImageSequence { - public: - ImageSequence(const std::string& session_id, const std::string& url) : idx_(0) { +public: + ImageSequence(const std::string &session_id, const std::string &url) + : idx_(0) { namespace fs = std::filesystem; std::string host_name; @@ -196,18 +196,17 @@ class ImageSequence { zf.extractall(dir_path.string()); } else { std::ofstream ofs(dir_path / fs::path(url).filename(), std::ios::binary); - ofs.write(reinterpret_cast(data.data()), data.size()); + ofs.write(reinterpret_cast(data.data()), data.size()); } - for (auto& d : fs::directory_iterator(dir_path)) { + for (auto &d : fs::directory_iterator(dir_path)) { paths_.push_back(d.path()); } // Dictionary order std::sort(paths_.begin(), paths_.end()); + } - } - - void get(int width, int height, int imread_flags, Halide::Runtime::Buffer &buf) { + void get(int width, int height, int imread_flags, Halide::Runtime::Buffer &buf) { namespace fs = std::filesystem; auto path = paths_[idx_]; @@ -215,64 +214,59 @@ class ImageSequence { std::ifstream ifs(path, std::ios::binary); std::vector img_data(size); - ifs.read(reinterpret_cast(img_data.data()), size); + ifs.read(reinterpret_cast(img_data.data()), size); if (path.extension() == ".raw") { switch (imread_flags) { - case IMREAD_GRAYSCALE: - if (size == width * height * sizeof(uint8_t)) { - Halide::Runtime::Buffer buf_8(std::vector{width, height}); //read in 8 bit - std::memcpy(buf_8.data(), img_data.data(), size); // set_img_data - auto buf_16 = Halide::Tools::ImageTypeConversion::convert_image(buf_8, halide_type_of()); - buf.copy_from(buf_16); - } else if (size == width * height * sizeof(uint16_t)) { - std::memcpy(buf.data(), img_data.data(), size); - } else { - throw std::runtime_error("Unsupported raw format"); - } - break; - case IMREAD_COLOR: - if (size == 3 * width * height * sizeof(uint8_t)) { - // Expect interleaved RGB - Halide::Runtime::Buffer buf_interleaved = Halide::Runtime::Buffer ::make_interleaved(width, height, 3); ; - std::memcpy(buf_interleaved.data(), img_data.data(), size); // set_img_data - auto buffer_planar = buf_interleaved.copy_to_planar(); - buf.copy_from(buffer_planar); - } else { - throw std::runtime_error("Unsupported raw format"); - } - break; - default: - throw std::runtime_error("Unsupported flags"); + case IMREAD_GRAYSCALE: + if (size == width * height * sizeof(uint8_t)) { + Halide::Runtime::Buffer buf_8(std::vector{width, height}); // read in 8 bit + std::memcpy(buf_8.data(), img_data.data(), size); // set_img_data + auto buf_16 = Halide::Tools::ImageTypeConversion::convert_image(buf_8, halide_type_of()); + buf.copy_from(buf_16); + } else if (size == width * height * sizeof(uint16_t)) { + std::memcpy(buf.data(), img_data.data(), size); + } else { + throw std::runtime_error("Unsupported raw format"); + } + break; + case IMREAD_COLOR: + if (size == 3 * width * height * sizeof(uint8_t)) { + // Expect interleaved RGB + Halide::Runtime::Buffer buf_interleaved = Halide::Runtime::Buffer::make_interleaved(width, height, 3); + ; + std::memcpy(buf_interleaved.data(), img_data.data(), size); // set_img_data + auto buffer_planar = buf_interleaved.copy_to_planar(); + buf.copy_from(buffer_planar); + } else { + throw std::runtime_error("Unsupported raw format"); + } + break; + default: + throw std::runtime_error("Unsupported flags"); } } else { switch (imread_flags) { - case IMREAD_GRAYSCALE: - { - Halide::Runtime::Buffer img_buf = Halide::Tools::load_and_convert_image(path.string()); - buf.copy_from(img_buf); - } - break; - case IMREAD_COLOR: - { - Halide::Runtime::Buffer img_buf = Halide::Tools::load_image(path.string()); - buf.copy_from(img_buf); - } - break; - default: - throw std::runtime_error("Unsupported flags"); + case IMREAD_GRAYSCALE: { + Halide::Runtime::Buffer img_buf = Halide::Tools::load_and_convert_image(path.string()); + buf.copy_from(img_buf); + } break; + case IMREAD_COLOR: { + Halide::Runtime::Buffer img_buf = Halide::Tools::load_image(path.string()); + buf.copy_from(img_buf); + } break; + default: + throw std::runtime_error("Unsupported flags"); } - } - idx_ = ((idx_+1) % paths_.size()); + idx_ = ((idx_ + 1) % paths_.size()); return; } - private: +private: int32_t idx_; std::vector paths_; }; - struct rawHeader { // ---------- 0 @@ -307,19 +301,19 @@ struct rawHeader { // PFNC // https://www.emva.org/wp-content/uploads/GenICamPixelFormatValues.pdf -#define PFNC_Mono8 0x01080001 //PFNC Monochrome 8-bit -#define PFNC_Mono10 0x01100003 //PFNC Monochrome 10-bit unpacked -#define PFNC_Mono12 0x01100005 //PFNC Monochrome 12-bit unpacked -#define PFNC_RGB8 0x02180014 //PFNC Red-Green-Blue 8-bit -#define PFNC_BGR8 0x02180015 //PFNC Blue-Green-Red 8-bit - -#define PFNC_BayerBG8 0x0108000B //PFNC Bayer Blue-Green 8-bit -#define PFNC_BayerBG10 0x0110000F //PFNC Bayer Blue-Green 10-bit unpacked -#define PFNC_BayerBG12 0x01100013 //PFNC Bayer Blue-Green 12-bit unpacked - -#define PFNC_BayerGR8 0x01080008 //PFNC Bayer Green-Red 8-bit -#define PFNC_BayerGR12 0x01100010 //PFNC Bayer Green-Red 12-bit unpacked -#define PFNC_YCbCr422_8 0x0210003B //PFNC YCbCr 4:2:2 8-bit +#define PFNC_Mono8 0x01080001 // PFNC Monochrome 8-bit +#define PFNC_Mono10 0x01100003 // PFNC Monochrome 10-bit unpacked +#define PFNC_Mono12 0x01100005 // PFNC Monochrome 12-bit unpacked +#define PFNC_RGB8 0x02180014 // PFNC Red-Green-Blue 8-bit +#define PFNC_BGR8 0x02180015 // PFNC Blue-Green-Red 8-bit + +#define PFNC_BayerBG8 0x0108000B // PFNC Bayer Blue-Green 8-bit +#define PFNC_BayerBG10 0x0110000F // PFNC Bayer Blue-Green 10-bit unpacked +#define PFNC_BayerBG12 0x01100013 // PFNC Bayer Blue-Green 12-bit unpacked + +#define PFNC_BayerGR8 0x01080008 // PFNC Bayer Green-Red 8-bit +#define PFNC_BayerGR12 0x01100010 // PFNC Bayer Green-Red 12-bit unpacked +#define PFNC_YCbCr422_8 0x0210003B // PFNC YCbCr 4:2:2 8-bit } // namespace image_io } // namespace bb diff --git a/src/bb/image-io/rt_display.h b/src/bb/image-io/rt_display.h index 5a17039e..9089520f 100644 --- a/src/bb/image-io/rt_display.h +++ b/src/bb/image-io/rt_display.h @@ -121,12 +121,12 @@ extern "C" ION_EXPORT int ion_bb_image_io_gui_display(halide_buffer_t *in, int w in->dim[2].extent = height; } else { if (getenv("DISPLAY")) { - auto& cv(ion::bb::OpenCV::get_instance()); + auto &cv(ion::bb::OpenCV::get_instance()); Halide::Runtime::Buffer ibuf(*in); ibuf.copy_to_host(); auto img = cv.cvCreateMatHeader(height, width, CV_MAKETYPE(CV_8U, 3)); - cv.cvSetData(img, in->host, 3*width*sizeof(uint8_t)); + cv.cvSetData(img, in->host, 3 * width * sizeof(uint8_t)); auto name = "img" + std::to_string(idx); cv.cvShowImage(name.c_str(), img); diff --git a/src/bb/image-io/rt_file.h b/src/bb/image-io/rt_file.h index 5e075aee..f4e1f5ba 100644 --- a/src/bb/image-io/rt_file.h +++ b/src/bb/image-io/rt_file.h @@ -19,7 +19,6 @@ #include "opencv_loader.h" - extern "C" int ION_EXPORT ion_bb_image_io_color_data_loader(halide_buffer_t *session_id_buf, halide_buffer_t *url_buf, int32_t width, int32_t height, halide_buffer_t *out) { using namespace ion::bb::image_io; @@ -93,7 +92,7 @@ extern "C" int ION_EXPORT ion_bb_image_io_image_saver(halide_buffer_t *in, int32 in->dim[2].extent = height; } else { Halide::Runtime::Buffer obuf = Halide::Runtime::Buffer::make_interleaved(width, height, 3); - std::memcpy(obuf.data(), in->host, 3* width*height*sizeof(uint8_t)); + std::memcpy(obuf.data(), in->host, 3 * width * height * sizeof(uint8_t)); Halide::Tools::save_image(obuf, reinterpret_cast(path->host)); } } catch (const std::exception &e) { @@ -108,32 +107,28 @@ extern "C" int ION_EXPORT ion_bb_image_io_image_saver(halide_buffer_t *in, int32 } ION_REGISTER_EXTERN(ion_bb_image_io_image_saver); - namespace { class Writer { public: - static Writer& get_instance(const std::string& id, std::vector& payload_size, const ::std::string& output_directory, bool write_framecount, const std::string& prefix = "raw-") - { + static Writer &get_instance(const std::string &id, std::vector &payload_size, const ::std::string &output_directory, bool write_framecount, const std::string &prefix = "raw-") { if (instances.count(id) == 0) { - instances[id] = std::unique_ptr(new Writer(payload_size, output_directory, write_framecount, prefix )); + instances[id] = std::unique_ptr(new Writer(payload_size, output_directory, write_framecount, prefix)); } return *instances[id]; } ~Writer() { - if (!disposed_){ + if (!disposed_) { ion::log::debug("Trying to call dispose from distructor since disposed_ is {}", disposed_); dispose(); } - } - void post_image(std::vector& outs, std::vector& size, - ion::bb::image_io::rawHeader& header_info, void* framecounts) - { - if (with_header_){ + void post_image(std::vector &outs, std::vector &size, + ion::bb::image_io::rawHeader &header_info, void *framecounts) { + if (with_header_) { write_config_file(header_info); } ::std::unique_lock<::std::mutex> lock(mutex_); @@ -141,11 +136,11 @@ class Writer { if (ep_) { ::std::rethrow_exception(ep_); } - uint8_t* buffer = buf_queue_.front(); + uint8_t *buffer = buf_queue_.front(); buf_queue_.pop(); size_t offset = 0; - for (int i = 0; i < outs.size(); ++i){ - ::std::memcpy(buffer + offset, reinterpret_cast(framecounts) + i, sizeof(int32_t)); + for (int i = 0; i < outs.size(); ++i) { + ::std::memcpy(buffer + offset, reinterpret_cast(framecounts) + i, sizeof(int32_t)); offset += sizeof(int32_t); ::std::memcpy(buffer + offset, outs[i], size[i]); offset += size[i]; @@ -154,9 +149,8 @@ class Writer { task_cv_.notify_one(); } - void post_gendc(std::vector& outs, std::vector& size, ion::bb::image_io::rawHeader& header_info) - { - if (with_header_){ + void post_gendc(std::vector &outs, std::vector &size, ion::bb::image_io::rawHeader &header_info) { + if (with_header_) { write_config_file(header_info); } ::std::unique_lock<::std::mutex> lock(mutex_); @@ -164,10 +158,10 @@ class Writer { if (ep_) { ::std::rethrow_exception(ep_); } - uint8_t* buffer = buf_queue_.front(); + uint8_t *buffer = buf_queue_.front(); buf_queue_.pop(); size_t offset = 0; - for (int i = 0; i < outs.size(); ++i){ + for (int i = 0; i < outs.size(); ++i) { ::std::memcpy(buffer + offset, outs[i], size[i]); offset += size[i]; } @@ -176,7 +170,7 @@ class Writer { } void dispose() { - ion::log::debug("Writer::dispose() :: is called"); + ion::log::debug("Writer::dispose() :: is called"); // Already disposed if thread is not joinable if (thread_ && thread_->joinable()) { keep_running_ = false; @@ -184,25 +178,23 @@ class Writer { thread_->join(); thread_ = nullptr; } - ion::log::debug("Writer::dispose() :: is finished"); - disposed_ = true; + ion::log::debug("Writer::dispose() :: is finished"); + disposed_ = true; } - - static void release_instance(const char * id) { + static void release_instance(const char *id) { ion::log::debug("Writer::release_instance() :: is called"); if (instances.count(id) == 0) { - return; + return; } - Writer & writer = *instances[id].get(); + Writer &writer = *instances[id].get(); writer.dispose(); instances.erase(id); ion::log::debug("Writer::release_instance() :: Instance is delete"); + } - } - - void write_config_file(ion::bb::image_io::rawHeader& header_info){ + void write_config_file(ion::bb::image_io::rawHeader &header_info) { nlohmann::json j_sensor; j_sensor["prefix"] = prefix_; j_sensor["framerate"] = header_info.fps_; @@ -219,13 +211,12 @@ class Writer { } private: - Writer(std::vector& payload_size, const std::string& output_directory, bool write_framecount, const std::string& prefix) - : keep_running_(true), output_directory_(output_directory), with_header_(true), disposed_(false), prefix_(prefix) - { + Writer(std::vector &payload_size, const std::string &output_directory, bool write_framecount, const std::string &prefix) + : keep_running_(true), output_directory_(output_directory), with_header_(true), disposed_(false), prefix_(prefix) { int total_payload_size = 0; - for (auto s : payload_size){ + for (auto s : payload_size) { total_payload_size += s; - if (write_framecount){ + if (write_framecount) { total_payload_size += sizeof(int32_t); } } @@ -243,20 +234,19 @@ class Writer { int get_buffer_num(int width, int height, int num_sensor = 2, int data_in_byte = 2) { // fix the memory size 2GB const double memory_size_in_MB = 2048.0; - return static_cast( memory_size_in_MB * 1024 * 1024 / width / height / num_sensor / data_in_byte); + return static_cast(memory_size_in_MB * 1024 * 1024 / width / height / num_sensor / data_in_byte); } int get_buffer_num(int32_t payloadsize) { // fix the memory size 2GB const double memory_size_in_MB = 2048.0; - return static_cast( memory_size_in_MB * 1024 * 1024 / payloadsize); + return static_cast(memory_size_in_MB * 1024 * 1024 / payloadsize); } - static void entry_point(Writer* obj) { + static void entry_point(Writer *obj) { try { obj->thread_main(); - } - catch (...) { + } catch (...) { ::std::unique_lock<::std::mutex> lock(obj->mutex_); obj->ep_ = ::std::current_exception(); } @@ -264,7 +254,7 @@ class Writer { void thread_main() { uint32_t frame_count; - uint8_t* buffer; + uint8_t *buffer; size_t size; // Main loop @@ -288,7 +278,7 @@ class Writer { ofs_ = ::std::ofstream(output_directory_ / (prefix_ + ::std::to_string(file_idx++) + ".bin"), ::std::ios::binary); } - ofs_.write(reinterpret_cast(buffer), size); + ofs_.write(reinterpret_cast(buffer), size); { ::std::unique_lock<::std::mutex> lock(mutex_); @@ -315,8 +305,7 @@ class Writer { ofs_ = ::std::ofstream(output_directory_ / (prefix_ + ::std::to_string(file_idx++) + ".bin"), ::std::ios::binary); } - - ofs_.write(reinterpret_cast(buffer), size); + ofs_.write(reinterpret_cast(buffer), size); { ::std::unique_lock<::std::mutex> lock(mutex_); @@ -330,14 +319,14 @@ class Writer { ofs_.close(); } - static ::std::unordered_map < ::std::string, std::unique_ptr> instances; // declares Writer::instance + static ::std::unordered_map<::std::string, std::unique_ptr> instances; // declares Writer::instance ::std::shared_ptr<::std::thread> thread_; ::std::vector<::std::vector> buffers_; ::std::mutex mutex_; ::std::condition_variable buf_cv_; ::std::condition_variable task_cv_; - ::std::queue buf_queue_; - ::std::queue<::std::tuple> task_queue_; + ::std::queue buf_queue_; + ::std::queue<::std::tuple> task_queue_; bool keep_running_; ::std::exception_ptr ep_; ::std::ofstream ofs_; @@ -350,28 +339,23 @@ class Writer { bool with_header_; }; -::std::unordered_map< ::std::string, std::unique_ptr> Writer::instances; // defines Writer::instance -} // namespace +::std::unordered_map<::std::string, std::unique_ptr> Writer::instances; // defines Writer::instance +} // namespace - -extern "C" -int ION_EXPORT writer_dispose(const char *id) { +extern "C" int ION_EXPORT writer_dispose(const char *id) { Writer::release_instance(id); return 0; } - -extern "C" ION_EXPORT -int ion_bb_image_io_binary_gendc_saver( halide_buffer_t * id_buf, halide_buffer_t * gendc, halide_buffer_t * deviceinfo, - int payloadsize, halide_buffer_t* output_directory_buf, halide_buffer_t* prefix_buf, - halide_buffer_t * out) - { +extern "C" ION_EXPORT int ion_bb_image_io_binary_gendc_saver(halide_buffer_t *id_buf, halide_buffer_t *gendc, halide_buffer_t *deviceinfo, + int payloadsize, halide_buffer_t *output_directory_buf, halide_buffer_t *prefix_buf, + halide_buffer_t *out) { try { const std::string id(reinterpret_cast(id_buf->host)); - const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); - std::vectorpayloadsize_list{payloadsize}; - const ::std::string prefix(reinterpret_cast(prefix_buf->host)); - auto& w(Writer::get_instance(id,payloadsize_list, output_directory, false, prefix)); + const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); + std::vector payloadsize_list{payloadsize}; + const ::std::string prefix(reinterpret_cast(prefix_buf->host)); + auto &w(Writer::get_instance(id, payloadsize_list, output_directory, false, prefix)); if (gendc->is_bounds_query() || deviceinfo->is_bounds_query()) { if (gendc->is_bounds_query()) { gendc->dim[0].min = 0; @@ -382,25 +366,20 @@ int ion_bb_image_io_binary_gendc_saver( halide_buffer_t * id_buf, halide_buffer_ deviceinfo->dim[0].extent = sizeof(ion::bb::image_io::rawHeader); } return 0; - } - else { + } else { ion::bb::image_io::rawHeader header_info; ::memcpy(&header_info, deviceinfo->host, sizeof(ion::bb::image_io::rawHeader)); std::vector obufs{gendc->host}; std::vector size_in_bytes{gendc->size_in_bytes()}; w.post_gendc(obufs, size_in_bytes, header_info); - - } return 0; - } - catch (const ::std::exception& e) { + } catch (const ::std::exception &e) { ::std::cerr << e.what() << ::std::endl; return -1; - } - catch (...) { + } catch (...) { ::std::cerr << "Unknown error" << ::std::endl; return -1; } @@ -408,22 +387,20 @@ int ion_bb_image_io_binary_gendc_saver( halide_buffer_t * id_buf, halide_buffer_ ION_REGISTER_EXTERN(ion_bb_image_io_binary_gendc_saver); -extern "C" ION_EXPORT -int ion_bb_image_io_binary_image_saver( - halide_buffer_t * id_buf, - halide_buffer_t * image, halide_buffer_t * deviceinfo, halide_buffer_t * frame_count, - int width, int height, int dim, int byte_depth, halide_buffer_t* output_directory_buf, - halide_buffer_t* prefix_buf, - halide_buffer_t * out) - { +extern "C" ION_EXPORT int ion_bb_image_io_binary_image_saver( + halide_buffer_t *id_buf, + halide_buffer_t *image, halide_buffer_t *deviceinfo, halide_buffer_t *frame_count, + int width, int height, int dim, int byte_depth, halide_buffer_t *output_directory_buf, + halide_buffer_t *prefix_buf, + halide_buffer_t *out) { try { int num_output = 1; const std::string id(reinterpret_cast(id_buf->host)); int32_t frame_size = dim == 2 ? width * height * byte_depth : width * height * 3 * byte_depth; - std::vectorframe_size_list{frame_size}; - const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); - const ::std::string prefix(reinterpret_cast(prefix_buf->host)); - auto& w(Writer::get_instance(id, frame_size_list, output_directory, true, prefix)); + std::vector frame_size_list{frame_size}; + const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); + const ::std::string prefix(reinterpret_cast(prefix_buf->host)); + auto &w(Writer::get_instance(id, frame_size_list, output_directory, true, prefix)); if (image->is_bounds_query() || deviceinfo->is_bounds_query() || frame_count->is_bounds_query()) { if (image->is_bounds_query()) { @@ -431,7 +408,7 @@ int ion_bb_image_io_binary_image_saver( image->dim[0].extent = width; image->dim[1].min = 0; image->dim[1].extent = height; - if (dim == 3){ + if (dim == 3) { image->dim[2].min = 0; image->dim[2].extent = 3; } @@ -445,8 +422,7 @@ int ion_bb_image_io_binary_image_saver( frame_count->dim[0].extent = num_output; } return 0; - } - else { + } else { ion::bb::image_io::rawHeader header_info; memcpy(&header_info, deviceinfo->host, sizeof(ion::bb::image_io::rawHeader)); std::vector header_infos{header_info}; @@ -457,12 +433,10 @@ int ion_bb_image_io_binary_image_saver( } return 0; - } - catch (const ::std::exception& e) { + } catch (const ::std::exception &e) { ::std::cerr << e.what() << ::std::endl; return -1; - } - catch (...) { + } catch (...) { ::std::cerr << "Unknown error" << ::std::endl; return -1; } @@ -472,160 +446,156 @@ ION_REGISTER_EXTERN(ion_bb_image_io_binary_image_saver); namespace { - class Reader { - public: - static Reader& get_instance(::std::string session_id, int width, int height, const ::std::string& output_directory) { - auto it = instances.find(session_id); - if (it == instances.end()) { - instances[session_id] = std::unique_ptr(new Reader(width, height, output_directory)); - } - - return *instances[session_id]; +class Reader { +public: + static Reader &get_instance(::std::string session_id, int width, int height, const ::std::string &output_directory) { + auto it = instances.find(session_id); + if (it == instances.end()) { + instances[session_id] = std::unique_ptr(new Reader(width, height, output_directory)); } - void get(uint8_t* ptr0, uint8_t* ptr1, size_t size) { + return *instances[session_id]; + } - current_idx_ = file_idx_; + void get(uint8_t *ptr0, uint8_t *ptr1, size_t size) { - if (finished_) { - return; - } + current_idx_ = file_idx_; - if (read_count_ < offset_frame_count_) { - ::std::memset(ptr0, 0, size); - ::std::memset(ptr1, 0, size); - } - else { - uint32_t frame_count = 0; - ifs_.read(reinterpret_cast(&frame_count), sizeof(frame_count)); + if (finished_) { + return; + } - if (frame_count != (latest_frame_count_ + 1)) { - ifs_.seekg(-static_cast(sizeof(frame_count)), ::std::ios::cur); - } - else { - ifs_.read(reinterpret_cast(latest_frame0_.data()), size); - ifs_.read(reinterpret_cast(latest_frame1_.data()), size); - } + if (read_count_ < offset_frame_count_) { + ::std::memset(ptr0, 0, size); + ::std::memset(ptr1, 0, size); + } else { + uint32_t frame_count = 0; + ifs_.read(reinterpret_cast(&frame_count), sizeof(frame_count)); + + if (frame_count != (latest_frame_count_ + 1)) { + ifs_.seekg(-static_cast(sizeof(frame_count)), ::std::ios::cur); + } else { + ifs_.read(reinterpret_cast(latest_frame0_.data()), size); + ifs_.read(reinterpret_cast(latest_frame1_.data()), size); + } - ::std::memcpy(ptr0, latest_frame0_.data(), size); - ::std::memcpy(ptr1, latest_frame1_.data(), size); + ::std::memcpy(ptr0, latest_frame0_.data(), size); + ::std::memcpy(ptr1, latest_frame1_.data(), size); - latest_frame_count_++; + latest_frame_count_++; - // rotate - ifs_.peek(); - if (ifs_.eof()) { - open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); - if (finished_) { - ifs_ = ::std::ifstream(); - } + // rotate + ifs_.peek(); + if (ifs_.eof()) { + open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); + if (finished_) { + ifs_ = ::std::ifstream(); } } - read_count_++; - } - - void close() { - ifs_.close(); } + read_count_++; + } - bool get_finished() const { - return finished_; - } + void close() { + ifs_.close(); + } - uint32_t get_index() { - return current_idx_; - } + bool get_finished() const { + return finished_; + } - void release_instance(const ::std::string& session_id) { - instances.erase(session_id); - } + uint32_t get_index() { + return current_idx_; + } - private: - Reader(int width, int height, const ::std::string& output_directory) - : width_(width), height_(height), output_directory_(output_directory), - file_idx_(0), latest_frame0_(width* height), latest_frame1_(width* height), - latest_frame_count_((::std::numeric_limits::max)()), read_count_(0), finished_(false) - { + void release_instance(const ::std::string &session_id) { + instances.erase(session_id); + } - open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); - if (finished_) { - return; - } +private: + Reader(int width, int height, const ::std::string &output_directory) + : width_(width), height_(height), output_directory_(output_directory), + file_idx_(0), latest_frame0_(width * height), latest_frame1_(width * height), + latest_frame_count_((::std::numeric_limits::max)()), read_count_(0), finished_(false) { - // Determine counter might be reset to zero (may dropped first few frames) - uint32_t prev_frame_count = 0; - const size_t size = static_cast(width * height * sizeof(uint16_t)); - while (true) { - uint32_t frame_count = 0; - ifs_.read(reinterpret_cast(&frame_count), sizeof(frame_count)); - ifs_.seekg(2 * size, ::std::ios::cur); - if (prev_frame_count > frame_count) { - ifs_.seekg(-static_cast(sizeof(frame_count)) - 2 * size, ::std::ios::cur); - offset_frame_count_ = frame_count; - break; - } - prev_frame_count = frame_count; - ifs_.peek(); + open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); + if (finished_) { + return; + } - if (ifs_.eof()) { + // Determine counter might be reset to zero (may dropped first few frames) + uint32_t prev_frame_count = 0; + const size_t size = static_cast(width * height * sizeof(uint16_t)); + while (true) { + uint32_t frame_count = 0; + ifs_.read(reinterpret_cast(&frame_count), sizeof(frame_count)); + ifs_.seekg(2 * size, ::std::ios::cur); + if (prev_frame_count > frame_count) { + ifs_.seekg(-static_cast(sizeof(frame_count)) - 2 * size, ::std::ios::cur); + offset_frame_count_ = frame_count; + break; + } + prev_frame_count = frame_count; + ifs_.peek(); + + if (ifs_.eof()) { + open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); + if (finished_) { + // Seek to first file and set offset when We cannot find base frame. + file_idx_ = 0; open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); - if (finished_) { - // Seek to first file and set offset when We cannot find base frame. - file_idx_ = 0; - open_and_check(width_, height_, output_directory_, file_idx_, ifs_, &finished_); - ifs_.read(reinterpret_cast(&offset_frame_count_), sizeof(offset_frame_count_)); - ifs_.seekg(-static_cast(sizeof(offset_frame_count_)), ::std::ios::cur); - finished_ = false; - read_count_ = offset_frame_count_; - latest_frame_count_ = read_count_ - 1; - break; - } + ifs_.read(reinterpret_cast(&offset_frame_count_), sizeof(offset_frame_count_)); + ifs_.seekg(-static_cast(sizeof(offset_frame_count_)), ::std::ios::cur); + finished_ = false; + read_count_ = offset_frame_count_; + latest_frame_count_ = read_count_ - 1; + break; } } - current_idx_ = file_idx_; } + current_idx_ = file_idx_; + } - void open_and_check(uint32_t width, uint32_t height, const std::filesystem::path output_directory, uint32_t& file_idx, ::std::ifstream& ifs, bool* finished) { - auto file_path = output_directory / ("raw-" + ::std::to_string(file_idx++) + ".bin"); - - ifs = ::std::ifstream(file_path, ::std::ios::binary); - if (ifs.fail()) { - *finished = true; - return; - } + void open_and_check(uint32_t width, uint32_t height, const std::filesystem::path output_directory, uint32_t &file_idx, ::std::ifstream &ifs, bool *finished) { + auto file_path = output_directory / ("raw-" + ::std::to_string(file_idx++) + ".bin"); - // skip header (size is 512) - ifs.seekg(512, ::std::ios_base::beg); + ifs = ::std::ifstream(file_path, ::std::ios::binary); + if (ifs.fail()) { + *finished = true; + return; } - uint32_t width_; - uint32_t height_; - std::filesystem::path output_directory_; - uint32_t file_idx_; - ::std::vector latest_frame0_; - ::std::vector latest_frame1_; - uint32_t latest_frame_count_; - uint32_t offset_frame_count_; - uint32_t read_count_; - ::std::ifstream ifs_; - bool finished_; - - uint32_t current_idx_; - static ::std::unordered_map < ::std::string, std::unique_ptr> instances; // declares Writer::instance - }; - - ::std::unordered_map< ::std::string, std::unique_ptr> Reader::instances; // defines Writer::instance -} + // skip header (size is 512) + ifs.seekg(512, ::std::ios_base::beg); + } -extern "C" ION_EXPORT -int binaryloader(halide_buffer_t *session_id_buf, int width, int height, halide_buffer_t * output_directory_buf, - halide_buffer_t * out0, halide_buffer_t * out1) { + uint32_t width_; + uint32_t height_; + std::filesystem::path output_directory_; + uint32_t file_idx_; + ::std::vector latest_frame0_; + ::std::vector latest_frame1_; + uint32_t latest_frame_count_; + uint32_t offset_frame_count_; + uint32_t read_count_; + ::std::ifstream ifs_; + bool finished_; + + uint32_t current_idx_; + static ::std::unordered_map<::std::string, std::unique_ptr> instances; // declares Writer::instance +}; + +::std::unordered_map<::std::string, std::unique_ptr> Reader::instances; // defines Writer::instance +} // namespace + +extern "C" ION_EXPORT int binaryloader(halide_buffer_t *session_id_buf, int width, int height, halide_buffer_t *output_directory_buf, + halide_buffer_t *out0, halide_buffer_t *out1) { try { - const ::std::string session_id(reinterpret_cast(session_id_buf->host)); - const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); + const ::std::string session_id(reinterpret_cast(session_id_buf->host)); + const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); - auto& r(Reader::get_instance(session_id, width, height, output_directory)); + auto &r(Reader::get_instance(session_id, width, height, output_directory)); if (out0->is_bounds_query() || out1->is_bounds_query()) { if (out0->is_bounds_query()) { @@ -640,27 +610,23 @@ int binaryloader(halide_buffer_t *session_id_buf, int width, int height, halide_ out1->dim[1].min = 0; out1->dim[1].extent = height; } - } - else { - r.get(out0->host, out1->host, out0->size_in_bytes()); + } else { + r.get(out0->host, out1->host, out0->size_in_bytes()); } return 0; - } - catch (const ::std::exception& e) { + } catch (const ::std::exception &e) { ::std::cerr << e.what() << ::std::endl; return -1; - } - catch (...) { + } catch (...) { ::std::cerr << "Unknown error" << ::std::endl; return -1; } } ION_REGISTER_EXTERN(binaryloader); -extern "C" ION_EXPORT -int binaryloader_finished(halide_buffer_t* in0, halide_buffer_t* in1, halide_buffer_t *session_id_buf, int width, int height, - halide_buffer_t * output_directory_buf, - halide_buffer_t * finished, halide_buffer_t* bin_idx) { +extern "C" ION_EXPORT int binaryloader_finished(halide_buffer_t *in0, halide_buffer_t *in1, halide_buffer_t *session_id_buf, int width, int height, + halide_buffer_t *output_directory_buf, + halide_buffer_t *finished, halide_buffer_t *bin_idx) { try { if (in0->is_bounds_query() || in1->is_bounds_query()) { @@ -676,27 +642,24 @@ int binaryloader_finished(halide_buffer_t* in0, halide_buffer_t* in1, halide_buf in1->dim[1].min = 0; in1->dim[1].extent = height; } - } - else { - const ::std::string session_id(reinterpret_cast(session_id_buf->host)); - const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); - auto& r(Reader::get_instance(session_id, width, height, output_directory)); + } else { + const ::std::string session_id(reinterpret_cast(session_id_buf->host)); + const ::std::string output_directory(reinterpret_cast(output_directory_buf->host)); + auto &r(Reader::get_instance(session_id, width, height, output_directory)); auto finished_flag = r.get_finished(); - *reinterpret_cast(finished->host) = finished_flag; - *reinterpret_cast(bin_idx->host) = r.get_index(); - if (finished_flag) { - r.close(); - r.release_instance(session_id); - } + *reinterpret_cast(finished->host) = finished_flag; + *reinterpret_cast(bin_idx->host) = r.get_index(); + if (finished_flag) { + r.close(); + r.release_instance(session_id); + } } return 0; - } - catch (const ::std::exception& e) { + } catch (const ::std::exception &e) { ::std::cerr << e.what() << ::std::endl; return -1; - } - catch (...) { + } catch (...) { ::std::cerr << "Unknown error" << ::std::endl; return -1; } diff --git a/src/bb/image-io/rt_u3v.h b/src/bb/image-io/rt_u3v.h index 9f758256..6bc1c73f 100644 --- a/src/bb/image-io/rt_u3v.h +++ b/src/bb/image-io/rt_u3v.h @@ -15,11 +15,11 @@ #define ComponentIDIntensity 1 #ifdef _WIN32 - #define GOBJECT_FILE "gobject-2.0-0" - #define ARAVIS_FILE "aravis-0.8-0" +#define GOBJECT_FILE "gobject-2.0-0" +#define ARAVIS_FILE "aravis-0.8-0" #else - #define GOBJECT_FILE "gobject-2.0" - #define ARAVIS_FILE "aravis-0.8" +#define GOBJECT_FILE "gobject-2.0" +#define ARAVIS_FILE "aravis-0.8" #endif namespace ion { @@ -28,18 +28,16 @@ namespace image_io { class U3V { protected: - struct GError - { - uint32_t domain; - int32_t code; - const char *message; + struct GError { + uint32_t domain; + int32_t code; + const char *message; }; - enum OperationMode - { Came2USB1, - Came1USB1, - Came2USB2, - Came1USB2 + enum OperationMode { Came2USB1, + Came1USB1, + Came2USB2, + Came1USB2 }; enum FrameCountMethod { @@ -48,16 +46,16 @@ class U3V { TYPESPECIFIC3 = 1 }; - using gpointer = struct gpointer_*; + using gpointer = struct gpointer_ *; using g_object_unref_t = void (*)(gpointer); - typedef enum ArvAcquisitionMode{ + typedef enum ArvAcquisitionMode { ARV_ACQUISITION_MODE_CONTINUOUS, ARV_ACQUISITION_MODE_SINGLE_FRAME } ArvAcquisitionMode_t; - typedef enum ArvBufferStatus{ + typedef enum ArvBufferStatus { ARV_BUFFER_STATUS_UNKNOWN, ARV_BUFFER_STATUS_SUCCESS, ARV_BUFFER_STATUS_CLEARED, @@ -69,7 +67,7 @@ class U3V { ARV_BUFFER_STATUS_ABORTED } ArvBufferStatus_t; - typedef enum ArvBufferPayloadType{ + typedef enum ArvBufferPayloadType { ARV_BUFFER_PAYLOAD_TYPE_UNKNOWN, ARV_BUFFER_PAYLOAD_TYPE_IMAGE, ARV_BUFFER_PAYLOAD_TYPE_RAWDATA, @@ -80,107 +78,107 @@ class U3V { ARV_BUFFER_PAYLOAD_TYPE_JPEG2000, ARV_BUFFER_PAYLOAD_TYPE_H264, ARV_BUFFER_PAYLOAD_TYPE_MULTIZONE_IMAGE - }ArvBufferPayloadType_t; + } ArvBufferPayloadType_t; - typedef enum ArvDeviceStatus{ + typedef enum ArvDeviceStatus { ARV_DEVICE_STATUS_UNKNOWN, ARV_DEVICE_STATUS_SUCCESS, ARV_DEVICE_STATUS_TIMEOUT, ARV_DEVICE_STATUS_WRITE_ERROR - }ArvDeviceStatus_t; + } ArvDeviceStatus_t; - typedef enum ArvUvUsbMode{ + typedef enum ArvUvUsbMode { ARV_UV_USB_MODE_SYNC, ARV_UV_USB_MODE_ASYNC, ARV_UV_USB_MODE_DEFAULT = ARV_UV_USB_MODE_ASYNC } ArvUvUsbMode_t; - using ArvDevice_t = struct ArvDevice*; - using ArvFakeDevice_t = struct ArvFakeDevice*; - using ArvStream_t = struct ArvStream*; - using ArvStreamCallback_t = struct ArvStreamCallback*; - using ArvBuffer_t = struct ArvBuffer*; - using ArvGcNode_t = struct ArvGcNode*; - using ArvCamera_t = struct ArvCamera*; + using ArvDevice_t = struct ArvDevice *; + using ArvFakeDevice_t = struct ArvFakeDevice *; + using ArvStream_t = struct ArvStream *; + using ArvStreamCallback_t = struct ArvStreamCallback *; + using ArvBuffer_t = struct ArvBuffer *; + using ArvGcNode_t = struct ArvGcNode *; + using ArvCamera_t = struct ArvCamera *; - using arv_get_device_protocol_t = const char*(*)(unsigned int); + using arv_get_device_protocol_t = const char *(*)(unsigned int); - using arv_get_major_version_t = uint32_t(*)(); - using arv_get_minor_version_t = uint32_t(*)(); - using arv_get_micro_version_t = uint32_t(*)(); + using arv_get_major_version_t = uint32_t (*)(); + using arv_get_minor_version_t = uint32_t (*)(); + using arv_get_micro_version_t = uint32_t (*)(); - using arv_update_device_list_t = void(*)(); - using arv_get_n_devices_t = unsigned int(*)(); + using arv_update_device_list_t = void (*)(); + using arv_get_n_devices_t = unsigned int (*)(); - using arv_get_device_id_t = const char*(*)(unsigned int); - using arv_get_device_model_t = const char*(*)(unsigned int); - using arv_get_device_serial_nbr_t = const char*(*)(unsigned int); - using arv_open_device_t = ArvDevice*(*)(const char*, GError**); + using arv_get_device_id_t = const char *(*)(unsigned int); + using arv_get_device_model_t = const char *(*)(unsigned int); + using arv_get_device_serial_nbr_t = const char *(*)(unsigned int); + using arv_open_device_t = ArvDevice *(*)(const char *, GError **); - using arv_device_set_string_feature_value_t = void(*)(ArvDevice*, const char*, const char*, GError**); - using arv_device_set_float_feature_value_t = void(*)(ArvDevice*, const char*, double, GError**); - using arv_device_set_integer_feature_value_t = void(*)(ArvDevice*, const char*, int64_t, GError**); + using arv_device_set_string_feature_value_t = void (*)(ArvDevice *, const char *, const char *, GError **); + using arv_device_set_float_feature_value_t = void (*)(ArvDevice *, const char *, double, GError **); + using arv_device_set_integer_feature_value_t = void (*)(ArvDevice *, const char *, int64_t, GError **); - using arv_device_get_string_feature_value_t = const char *(*)(ArvDevice*, const char*, GError**); - using arv_device_get_integer_feature_value_t = int(*)(ArvDevice*, const char*, GError**); - using arv_device_get_float_feature_value_t = double(*)(ArvDevice*, const char*, GError**); + using arv_device_get_string_feature_value_t = const char *(*)(ArvDevice *, const char *, GError **); + using arv_device_get_integer_feature_value_t = int (*)(ArvDevice *, const char *, GError **); + using arv_device_get_float_feature_value_t = double (*)(ArvDevice *, const char *, GError **); - using arv_device_get_integer_feature_bounds_t = void(*)(ArvDevice*, const char*, int64_t*, int64_t*, GError**); - using arv_device_get_float_feature_bounds_t = void(*)(ArvDevice*, const char*, double*, double*, GError**); + using arv_device_get_integer_feature_bounds_t = void (*)(ArvDevice *, const char *, int64_t *, int64_t *, GError **); + using arv_device_get_float_feature_bounds_t = void (*)(ArvDevice *, const char *, double *, double *, GError **); - using arv_device_is_feature_available_t = bool(*)(ArvDevice*, const char*, GError**); + using arv_device_is_feature_available_t = bool (*)(ArvDevice *, const char *, GError **); - using arv_device_dup_register_feature_value_t = void*(*) (ArvDevice*, const char *, uint64_t *, GError **); + using arv_device_dup_register_feature_value_t = void *(*)(ArvDevice *, const char *, uint64_t *, GError **); - using arv_device_create_stream_t = ArvStream*(*)(ArvDevice*, ArvStreamCallback*, void*, GError**); + using arv_device_create_stream_t = ArvStream *(*)(ArvDevice *, ArvStreamCallback *, void *, GError **); - using arv_buffer_new_allocate_t = ArvBuffer*(*)(size_t); - using arv_stream_push_buffer_t = void(*)(ArvStream*, ArvBuffer*); + using arv_buffer_new_allocate_t = ArvBuffer *(*)(size_t); + using arv_stream_push_buffer_t = void (*)(ArvStream *, ArvBuffer *); - using arv_acquisition_mode_to_string_t = const char*(*)(ArvAcquisitionMode); - using arv_device_execute_command_t = void(*)(ArvDevice*, const char*, GError**); - using arv_stream_timeout_pop_buffer_t = ArvBuffer*(*)(ArvStream*, uint64_t); - using arv_stream_get_n_buffers_t = void(*)(ArvStream*, int32_t*, int32_t*); - using arv_buffer_get_status_t = ArvBufferStatus(*)(ArvBuffer*); - using arv_buffer_get_payload_type_t = ArvBufferPayloadType(*)(ArvBuffer*); - using arv_buffer_get_data_t = void*(*)(ArvBuffer*, size_t*); - using arv_buffer_get_part_data_t = void*(*)(ArvBuffer*, uint_fast32_t, size_t*); - using arv_buffer_get_timestamp_t = uint64_t(*)(ArvBuffer*); - using arv_device_get_feature_t = ArvGcNode*(*)(ArvDevice*, const char*); + using arv_acquisition_mode_to_string_t = const char *(*)(ArvAcquisitionMode); + using arv_device_execute_command_t = void (*)(ArvDevice *, const char *, GError **); + using arv_stream_timeout_pop_buffer_t = ArvBuffer *(*)(ArvStream *, uint64_t); + using arv_stream_get_n_buffers_t = void (*)(ArvStream *, int32_t *, int32_t *); + using arv_buffer_get_status_t = ArvBufferStatus (*)(ArvBuffer *); + using arv_buffer_get_payload_type_t = ArvBufferPayloadType (*)(ArvBuffer *); + using arv_buffer_get_data_t = void *(*)(ArvBuffer *, size_t *); + using arv_buffer_get_part_data_t = void *(*)(ArvBuffer *, uint_fast32_t, size_t *); + using arv_buffer_get_timestamp_t = uint64_t (*)(ArvBuffer *); + using arv_device_get_feature_t = ArvGcNode *(*)(ArvDevice *, const char *); - using arv_buffer_has_gendc_t = bool*(*)(ArvBuffer*); - using arv_buffer_get_gendc_descriptor_t = void*(*)(ArvBuffer*, size_t*); + using arv_buffer_has_gendc_t = bool *(*)(ArvBuffer *); + using arv_buffer_get_gendc_descriptor_t = void *(*)(ArvBuffer *, size_t *); - using arv_shutdown_t = void(*)(void); + using arv_shutdown_t = void (*)(void); - using arv_camera_new_t = ArvCamera*(*)(const char*, GError**); - using arv_camera_get_device_t = ArvDevice*(*)(ArvCamera *); - using arv_fake_device_new_t = ArvDevice*(*)(const char*, GError**); - using arv_set_fake_camera_genicam_filename_t = void(*)(const char*); - using arv_enable_interface_t = void(*)(const char*); - using arv_camera_create_stream_t = ArvStream*(*)(ArvCamera*, ArvStreamCallback*, void*, GError**); - using arv_fake_device_get_fake_camera_t = ArvCamera*(*)(ArvFakeDevice*); + using arv_camera_new_t = ArvCamera *(*)(const char *, GError **); + using arv_camera_get_device_t = ArvDevice *(*)(ArvCamera *); + using arv_fake_device_new_t = ArvDevice *(*)(const char *, GError **); + using arv_set_fake_camera_genicam_filename_t = void (*)(const char *); + using arv_enable_interface_t = void (*)(const char *); + using arv_camera_create_stream_t = ArvStream *(*)(ArvCamera *, ArvStreamCallback *, void *, GError **); + using arv_fake_device_get_fake_camera_t = ArvCamera *(*)(ArvFakeDevice *); - using arv_uv_device_set_usb_mode_t = void(*)(ArvDevice *, ArvUvUsbMode ); + using arv_uv_device_set_usb_mode_t = void (*)(ArvDevice *, ArvUvUsbMode); struct DeviceInfo { - const char* dev_id_; - ArvDevice* device_; - ArvCamera* camera_; + const char *dev_id_; + ArvDevice *device_; + ArvCamera *camera_; int32_t u3v_payload_size_; int32_t image_payload_size_; uint32_t frame_count_; - float gain_ =-1; - float exposure_ =-1; + float gain_ = -1; + float exposure_ = -1; int32_t int_gain_ = -1; int32_t int_exposure_ = -1; float exposure_range_[2]; - ArvStream* stream_; + ArvStream *stream_; // genDC int64_t data_offset_; @@ -191,17 +189,17 @@ class U3V { rawHeader header_info_; }; - public: - ~U3V(){ - if (!disposed_){ +public: + ~U3V() { + if (!disposed_) { log::debug("Trying to call dispose from distructor since disposed_ is {}", disposed_); dispose(); } } - void dispose(){ + void dispose() { log::debug("U3V::dispose() :: is called"); - for (auto i=0; i(d.stream_)); auto end = std::chrono::system_clock::now(); - log::debug("U3V::dispose() :: g_object_unref took {} ms", std::chrono::duration_cast(end-start).count()); + log::debug("U3V::dispose() :: g_object_unref took {} ms", std::chrono::duration_cast(end - start).count()); start = std::chrono::system_clock::now(); g_object_unref(reinterpret_cast(d.device_)); end = std::chrono::system_clock::now(); - log::debug("U3V::dispose() :: g_object_unref took {} ms", std::chrono::duration_cast(end-start).count()); + log::debug("U3V::dispose() :: g_object_unref took {} ms", std::chrono::duration_cast(end - start).count()); } devices_.clear(); @@ -229,96 +227,94 @@ class U3V { log::debug("U3V::dispose() :: Instance is deleted"); } - static void release_instance(const char * id) { + static void release_instance(const char *id) { log::debug("U3V::release_instance() :: is called"); if (instances_.count(id) == 0) { - return; + return; } - U3V & u3v = *instances_[id].get(); + U3V &u3v = *instances_[id].get(); u3v.dispose(); instances_.erase(id); log::debug("U3V::release_instance() :: is finished"); - - } + } void set_gain(int32_t sensor_idx, const std::string key, double v) { - if (is_param_integer_){ + if (is_param_integer_) { set_gain(sensor_idx, key, static_cast(v)); return; } - if (sensor_idx < num_sensor_ ){ - if(devices_[sensor_idx].gain_ != v){ - err_ = set(devices_[sensor_idx].device_, key.c_str(), v); + if (sensor_idx < num_sensor_) { + if (devices_[sensor_idx].gain_ != v) { + err_ = set(devices_[sensor_idx].device_, key.c_str(), v); devices_[sensor_idx].gain_ = v; } return; - }else{ + } else { throw std::runtime_error("the index number " + std::to_string(sensor_idx) + " exceeds the number of sensor " + std::to_string(num_sensor_)); } } void set_gain(int32_t sensor_idx, const std::string key, int32_t v) { - if (sensor_idx < num_sensor_ ){ - if(devices_[sensor_idx].int_gain_ != v){ - err_ = set(devices_[sensor_idx].device_, key.c_str(), static_cast(v)); + if (sensor_idx < num_sensor_) { + if (devices_[sensor_idx].int_gain_ != v) { + err_ = set(devices_[sensor_idx].device_, key.c_str(), static_cast(v)); devices_[sensor_idx].int_gain_ = v; } return; - }else{ + } else { throw std::runtime_error("the index number " + std::to_string(sensor_idx) + " exceeds the number of sensor " + std::to_string(num_sensor_)); } } void set_exposure(int32_t sensor_idx, const std::string key, double v) { - if (is_param_integer_){ + if (is_param_integer_) { set_exposure(sensor_idx, key, static_cast(v)); return; } - if (sensor_idx < num_sensor_ ){ - if(devices_[sensor_idx].exposure_ != v){ + if (sensor_idx < num_sensor_) { + if (devices_[sensor_idx].exposure_ != v) { err_ = set(devices_[sensor_idx].device_, key.c_str(), v); devices_[sensor_idx].exposure_ = v; } return; - }else{ + } else { throw std::runtime_error("the index number " + std::to_string(sensor_idx) + " exceeds the number of sensor " + std::to_string(num_sensor_)); } } void set_exposure(int32_t sensor_idx, const std::string key, int32_t v) { - if (sensor_idx < num_sensor_ ){ - if(devices_[sensor_idx].int_exposure_ != v){ + if (sensor_idx < num_sensor_) { + if (devices_[sensor_idx].int_exposure_ != v) { err_ = set(devices_[sensor_idx].device_, key.c_str(), static_cast(v)); devices_[sensor_idx].int_exposure_ = v; } return; - }else{ + } else { throw std::runtime_error("the index number " + std::to_string(sensor_idx) + " exceeds the number of sensor " + std::to_string(num_sensor_)); } } + virtual void get(std::vector> &outs){}; + virtual void get(std::vector &outs){}; - virtual void get(std::vector>& outs){}; - virtual void get(std::vector& outs){}; - - void get_frame_count(std::vector& outs){ - if (num_sensor_ != devices_.size()){ + void get_frame_count(std::vector &outs) { + if (num_sensor_ != devices_.size()) { ::memcpy(outs[0], &frame_cnt_, sizeof(uint32_t)); - }else{ - for (int nd = 0; nd < num_sensor_; nd++){ + } else { + for (int nd = 0; nd < num_sensor_; nd++) { ::memcpy(outs[nd], &devices_[nd].frame_count_, sizeof(uint32_t)); } } } - void get_device_info(std::vector& outs){ - if (sim_mode_||operation_mode_ == OperationMode::Came2USB2 || operation_mode_ == OperationMode::Came1USB1){ - for (int i = 0; i < num_sensor_; ++i){ + void get_device_info(std::vector &outs) { + if (sim_mode_ || operation_mode_ == OperationMode::Came2USB2 || operation_mode_ == OperationMode::Came1USB1) { + for (int i = 0; i < num_sensor_; ++i) { ::memcpy(outs[i], &(devices_[i].header_info_), sizeof(ion::bb::image_io::rawHeader)); log::trace("Obtained Device info USB{}", i); } @@ -329,43 +325,41 @@ class U3V { } protected: - U3V(int32_t num_sensor, bool frame_sync, bool realtime_display_mode, bool sim_mode, int32_t width, int32_t height , float_t fps, const std::string & pixel_format, char* dev_id = nullptr) - : gobject_(GOBJECT_FILE, true), aravis_(ARAVIS_FILE, true, true), - num_sensor_(num_sensor), frame_count_method_(FrameCountMethod::UNAVAILABLE), - frame_sync_(frame_sync), realtime_display_mode_(realtime_display_mode), is_gendc_(false), is_param_integer_(false), - devices_(num_sensor), buffers_(num_sensor), operation_mode_(OperationMode::Came1USB1), frame_cnt_(0), device_idx_(-1), disposed_(false), sim_mode_(sim_mode), order_filp_(false) - { + U3V(int32_t num_sensor, bool frame_sync, bool realtime_display_mode, bool sim_mode, int32_t width, int32_t height, float_t fps, const std::string &pixel_format, char *dev_id = nullptr) + : gobject_(GOBJECT_FILE, true), aravis_(ARAVIS_FILE, true, true), + num_sensor_(num_sensor), frame_count_method_(FrameCountMethod::UNAVAILABLE), + frame_sync_(frame_sync), realtime_display_mode_(realtime_display_mode), is_gendc_(false), is_param_integer_(false), + devices_(num_sensor), buffers_(num_sensor), operation_mode_(OperationMode::Came1USB1), frame_cnt_(0), device_idx_(-1), disposed_(false), sim_mode_(sim_mode), order_filp_(false) { init_symbols(); log::debug("U3V:: 24-09-03 : Tested on device 1.2"); log::info("Using aravis-{}.{}.{}", arv_get_major_version(), arv_get_minor_version(), arv_get_micro_version()); } - void init_symbols_gobject() { - if (!gobject_.is_available()) { + if (!gobject_.is_available()) { throw ::std::runtime_error("libgobject-2.0 is unavailable on your system."); } - #define GET_SYMBOL(LOCAL_VAR, TARGET_SYMBOL) \ - LOCAL_VAR = gobject_.get_symbol(TARGET_SYMBOL); \ - if (LOCAL_VAR == nullptr) { \ - throw ::std::runtime_error( \ - TARGET_SYMBOL " is unavailable on gobject-2.0"); \ - } +#define GET_SYMBOL(LOCAL_VAR, TARGET_SYMBOL) \ + LOCAL_VAR = gobject_.get_symbol(TARGET_SYMBOL); \ + if (LOCAL_VAR == nullptr) { \ + throw ::std::runtime_error( \ + TARGET_SYMBOL " is unavailable on gobject-2.0"); \ + } GET_SYMBOL(g_object_unref, "g_object_unref"); - #undef GET_SYMBOL +#undef GET_SYMBOL } void init_symbols_aravis() { - #define GET_SYMBOL(LOCAL_VAR, TARGET_SYMBOL) \ - LOCAL_VAR = aravis_.get_symbol(TARGET_SYMBOL); \ - if (LOCAL_VAR == nullptr) { \ - throw ::std::runtime_error( \ - TARGET_SYMBOL " is unavailable on aravis-0.8"); \ - } +#define GET_SYMBOL(LOCAL_VAR, TARGET_SYMBOL) \ + LOCAL_VAR = aravis_.get_symbol(TARGET_SYMBOL); \ + if (LOCAL_VAR == nullptr) { \ + throw ::std::runtime_error( \ + TARGET_SYMBOL " is unavailable on aravis-0.8"); \ + } GET_SYMBOL(arv_get_major_version, "arv_get_major_version"); GET_SYMBOL(arv_get_minor_version, "arv_get_minor_version"); @@ -420,7 +414,7 @@ class U3V { GET_SYMBOL(arv_enable_interface, "arv_enable_interface"); GET_SYMBOL(arv_fake_device_get_fake_camera, "arv_fake_device_get_fake_camera"); GET_SYMBOL(arv_uv_device_set_usb_mode, "arv_uv_device_set_usb_mode"); - #undef GET_SYMBOL +#undef GET_SYMBOL } void init_symbols() { @@ -428,45 +422,46 @@ class U3V { init_symbols_aravis(); } - int32_t get_frame_count_from_genDC_descriptor(ArvBuffer * buf, DeviceInfo& d){ - int32_t frame_count = 0;; - memcpy (&frame_count, ((char *) arv_buffer_get_data(buf, nullptr) + d.framecount_offset_), sizeof(int32_t)); + int32_t get_frame_count_from_genDC_descriptor(ArvBuffer *buf, DeviceInfo &d) { + int32_t frame_count = 0; + ; + memcpy(&frame_count, ((char *)arv_buffer_get_data(buf, nullptr) + d.framecount_offset_), sizeof(int32_t)); return frame_count; } template - GError* set(ArvDevice* dev_handle, const char* key, T v) { + GError *set(ArvDevice *dev_handle, const char *key, T v) { return set_feature_value(dev_handle, key, v); } template - GError* get(ArvDevice* dev_handle, const char* key, T* v) { + GError *get(ArvDevice *dev_handle, const char *key, T *v) { T vp; err_ = get_feature_value(dev_handle, key, vp); *v = vp; return err_; } - GError* set_feature_value(ArvDevice *device, const char *feature, const char *value){ - arv_device_set_string_feature_value (device, feature, value, &err_); + GError *set_feature_value(ArvDevice *device, const char *feature, const char *value) { + arv_device_set_string_feature_value(device, feature, value, &err_); return err_; } - GError* set_feature_value(ArvDevice *device, const char *feature, double value){ + GError *set_feature_value(ArvDevice *device, const char *feature, double value) { double min_v, max_v; - arv_device_get_float_feature_bounds (device, feature, &min_v, &max_v, &err_); + arv_device_get_float_feature_bounds(device, feature, &min_v, &max_v, &err_); if (err_ != nullptr) { return err_; } value = (std::max)(min_v, value); value = (std::min)(max_v, value); - arv_device_set_float_feature_value (device, feature, value, &err_); + arv_device_set_float_feature_value(device, feature, value, &err_); return err_; } - GError* set_feature_value(ArvDevice *device, const char *feature, int64_t value){ + GError *set_feature_value(ArvDevice *device, const char *feature, int64_t value) { int64_t min_v, max_v; arv_device_get_integer_feature_bounds(device, feature, &min_v, &max_v, &err_); if (err_ != nullptr) { @@ -475,17 +470,17 @@ class U3V { value = (std::max)(min_v, value); value = (std::min)(max_v, value); - arv_device_set_integer_feature_value (device, feature, value, &err_); + arv_device_set_integer_feature_value(device, feature, value, &err_); return err_; } - GError* get_feature_value(ArvDevice *device, const char *feature, int64_t& value){ + GError *get_feature_value(ArvDevice *device, const char *feature, int64_t &value) { value = arv_device_get_integer_feature_value(device, feature, &err_); return err_; } - void validate_user_input(int32_t num_detected_device, char* dev_id){ - if (num_detected_device < num_sensor_){ + void validate_user_input(int32_t num_detected_device, char *dev_id) { + if (num_detected_device < num_sensor_) { log::info("{} device is found; but the num_sensor is set to {}", num_detected_device, num_sensor_); throw std::runtime_error("Device number is not match, please set num_device again"); } @@ -500,8 +495,8 @@ class U3V { log::info("Acquisition option::{} is {}", "realtime_display_mode_", realtime_display_mode_); } - void command_acquisition_mode_contd_and_start(){ - for (auto i=0; imessage); @@ -516,63 +511,63 @@ class U3V { } } - void open_fake_devices(int32_t width, int32_t height , float_t fps, const std::string & pixel_format){ + void open_fake_devices(int32_t width, int32_t height, float_t fps, const std::string &pixel_format) { auto path = std::getenv("GENICAM_FILENAME"); - if (path == nullptr){ + if (path == nullptr) { throw std::runtime_error("Please define GENICAM_FILENAME by `set GENICAM_FILENAME=` or `export GENICAM_FILENAME=`"); } pixel_format_ = pixel_format; - arv_set_fake_camera_genicam_filename (path); + arv_set_fake_camera_genicam_filename(path); - arv_enable_interface ("Fake"); + arv_enable_interface("Fake"); log::info("Creating U3V instance with {} fake sensors...", num_sensor_); - auto fake_camera0 = arv_camera_new ("Fake_1", &err_); + auto fake_camera0 = arv_camera_new("Fake_1", &err_); if (err_) { throw std::runtime_error(err_->message); } auto fake_device0 = arv_camera_get_device(fake_camera0); devices_[0].device_ = fake_device0; - devices_[0].dev_id_= "fake_0"; + devices_[0].dev_id_ = "fake_0"; devices_[0].camera_ = fake_camera0; - if (num_sensor_==2){ + if (num_sensor_ == 2) { // aravis only provide on ARV_FAKE_DEVICE_ID https://github.com/Sensing-Dev/aravis/blob/main/src/arvfakeinterface.c - auto fake_camera1 = arv_camera_new ("Fake_1", &err_); + auto fake_camera1 = arv_camera_new("Fake_1", &err_); if (err_) { throw std::runtime_error(err_->message); } auto fake_device1 = arv_camera_get_device(fake_camera1); devices_[1].device_ = fake_device1; - devices_[1].dev_id_= "fake_1"; + devices_[1].dev_id_ = "fake_1"; devices_[1].camera_ = fake_camera1; } // Config fake cameras - for (int i = 0;i< num_sensor_;i++){ + for (int i = 0; i < num_sensor_; i++) { // setting the params if it is not zero log::info("Width {}, Height {} PixelFormat {}...", width, height, pixel_format_); - arv_device_set_integer_feature_value (devices_[i].device_, "Width", width, &err_); + arv_device_set_integer_feature_value(devices_[i].device_, "Width", width, &err_); if (err_) { throw std::runtime_error(err_->message); } - arv_device_set_integer_feature_value (devices_[i].device_, "Height", height, &err_); + arv_device_set_integer_feature_value(devices_[i].device_, "Height", height, &err_); if (err_) { throw std::runtime_error(err_->message); } - arv_device_set_float_feature_value (devices_[i].device_, "AcquisitionFrameRate",fps, &err_); + arv_device_set_float_feature_value(devices_[i].device_, "AcquisitionFrameRate", fps, &err_); if (err_) { throw std::runtime_error(err_->message); } - if (pixel_format_ != "Mono8"){ + if (pixel_format_ != "Mono8") { arv_device_set_string_feature_value(devices_[i].device_, "PixelFormat", pixel_format.c_str(), &err_); if (err_) { throw std::runtime_error(err_->message); } } - devices_[i].u3v_payload_size_ = arv_device_get_integer_feature_value (devices_[i].device_, "PayloadSize", &err_); + devices_[i].u3v_payload_size_ = arv_device_get_integer_feature_value(devices_[i].device_, "PayloadSize", &err_); if (err_) { throw std::runtime_error(err_->message); } - auto px =arv_device_get_integer_feature_value(devices_[i].device_, "PixelFormat", &err_); + auto px = arv_device_get_integer_feature_value(devices_[i].device_, "PixelFormat", &err_); if (err_) { throw std::runtime_error(err_->message); } @@ -580,41 +575,39 @@ class U3V { if (err_) { throw std::runtime_error(err_->message); } - struct rawHeader header= { 1, width, height, - 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, - width, height, width, height, static_cast(fps), px, 0}; + struct rawHeader header = {1, width, height, + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + width, height, width, height, static_cast(fps), px, 0}; devices_[i].header_info_ = header; devices_[i].image_payload_size_ = devices_[i].u3v_payload_size_; - devices_[i].frame_count_ = 0; - + devices_[i].frame_count_ = 0; } } - void open_real_devices(int32_t num_detected_device, int32_t num_usb_to_open, char* dev_id){ + void open_real_devices(int32_t num_detected_device, int32_t num_usb_to_open, char *dev_id) { int index_on_detected_device = 0; int index_on_opened_device = 0; - while (index_on_detected_device < num_detected_device && index_on_opened_device < num_usb_to_open){ + while (index_on_detected_device < num_detected_device && index_on_opened_device < num_usb_to_open) { - const char* device_protocol = arv_get_device_protocol(index_on_detected_device); - if (strcmp(device_protocol, "USB3Vision") == 0){ + const char *device_protocol = arv_get_device_protocol(index_on_detected_device); + if (strcmp(device_protocol, "USB3Vision") == 0) { - if (dev_id != nullptr && dev_id == arv_get_device_id (index_on_detected_device)){ + if (dev_id != nullptr && dev_id == arv_get_device_id(index_on_detected_device)) { /* if device id is specified TODO: dev_id may be more than 1 */ devices_[index_on_opened_device].dev_id_ = dev_id; - } - else{ + } else { /* if device id is not specified */ - devices_[index_on_opened_device].dev_id_ = arv_get_device_id (index_on_detected_device); + devices_[index_on_opened_device].dev_id_ = arv_get_device_id(index_on_detected_device); } log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "DeviceID", devices_[index_on_opened_device].dev_id_); devices_[index_on_opened_device].device_ = arv_open_device(devices_[index_on_opened_device].dev_id_, &err_); - if (err_ ) { + if (err_) { throw std::runtime_error(err_->message); } @@ -623,16 +616,16 @@ class U3V { } pixel_format_ = arv_device_get_string_feature_value(devices_[index_on_opened_device].device_, "PixelFormat", &err_); - if (err_ ) { + if (err_) { log::error(err_->message); err_ = nullptr; - }else{ + } else { log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "PixelFormat", pixel_format_); } - + // Here PayloadSize is the one for U3V data devices_[index_on_opened_device].u3v_payload_size_ = arv_device_get_integer_feature_value(devices_[index_on_opened_device].device_, "PayloadSize", &err_); - if (err_ ) { + if (err_) { throw std::runtime_error(err_->message); } log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "PayloadSize", devices_[index_on_opened_device].u3v_payload_size_); @@ -648,36 +641,36 @@ class U3V { } // check it the device is gendc mode =============================== - if (is_gendc_){ - const char * streaming_mode; + if (is_gendc_) { + const char *streaming_mode; streaming_mode = arv_device_get_string_feature_value(devices_[index_on_opened_device].device_, "GenDCStreamingMode", &err_); if (err_) { throw std::runtime_error(err_->message); } - is_gendc_ &= (strcmp(streaming_mode, "On")==0); + is_gendc_ &= (strcmp(streaming_mode, "On") == 0); } // Some type of U3V Camera supports Frame count generated by its device - const char* device_vender_name; + const char *device_vender_name; device_vender_name = arv_device_get_string_feature_value(devices_[index_on_opened_device].device_, "DeviceVendorName", &err_); if (err_) { log::error(err_->message); err_ = nullptr; - }else{ - if (strcmp(device_vender_name, "Sony Semiconductor Solutions Corporation")==0){ - const char* device_model_name; + } else { + if (strcmp(device_vender_name, "Sony Semiconductor Solutions Corporation") == 0) { + const char *device_model_name; device_model_name = arv_device_get_string_feature_value(devices_[index_on_opened_device].device_, "DeviceModelName", &err_); if (err_) { log::error(err_->message); err_ = nullptr; - }else{ - if (strcmp(device_model_name, " ")==0){ - is_param_integer_ = true; + } else { + if (strcmp(device_model_name, " ") == 0) { + is_param_integer_ = true; frame_count_method_ = FrameCountMethod::TIMESTAMP; - arv_uv_device_set_usb_mode(devices_[index_on_opened_device].device_, ARV_UV_USB_MODE_SYNC); //hotfix for v1.0 + arv_uv_device_set_usb_mode(devices_[index_on_opened_device].device_, ARV_UV_USB_MODE_SYNC); // hotfix for v1.0 } } - if (is_gendc_){ + if (is_gendc_) { frame_count_method_ = FrameCountMethod::TYPESPECIFIC3; order_filp_ = true; } @@ -685,51 +678,51 @@ class U3V { } log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "frame_count method is ", - frame_count_method_ == FrameCountMethod::TIMESTAMP ? "Timestamp": - frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? "TypeSpecific" : "Unavailabe"); + frame_count_method_ == FrameCountMethod::TIMESTAMP ? "Timestamp" : + frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? "TypeSpecific" : + "Unavailabe"); // Check each parameters for GenDC device ========================== int group_id = 0; - if (is_gendc_){ + if (is_gendc_) { log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "GenDC", "Available"); uint64_t gendc_desc_size = 0; - char* buffer = reinterpret_cast(arv_device_dup_register_feature_value(devices_[index_on_opened_device].device_,"GenDCDescriptor", &gendc_desc_size, &err_ )); + char *buffer = reinterpret_cast(arv_device_dup_register_feature_value(devices_[index_on_opened_device].device_, "GenDCDescriptor", &gendc_desc_size, &err_)); if (err_) { throw std::runtime_error(err_->message); } - if(isGenDC(buffer)){ - gendc_descriptor_= ContainerHeader(buffer); + if (isGenDC(buffer)) { + gendc_descriptor_ = ContainerHeader(buffer); std::tuple data_comp_and_part = gendc_descriptor_.getFirstAvailableDataOffset(true); - if (std::get<0>(data_comp_and_part) == -1){ + if (std::get<0>(data_comp_and_part) == -1) { devices_[index_on_opened_device].is_data_image_ = false; data_comp_and_part = gendc_descriptor_.getFirstAvailableDataOffset(false); - if (std::get<0>(data_comp_and_part) == -1){ + if (std::get<0>(data_comp_and_part) == -1) { throw std::runtime_error("None of the data in GenDC is available\n"); } - }else{ + } else { devices_[index_on_opened_device].is_data_image_ = true; } devices_[index_on_opened_device].data_offset_ = gendc_descriptor_.getDataOffset(std::get<0>(data_comp_and_part), std::get<1>(data_comp_and_part)); devices_[index_on_opened_device].image_payload_size_ = gendc_descriptor_.getDataSize(std::get<0>(data_comp_and_part), std::get<1>(data_comp_and_part)); - if (frame_count_method_ == FrameCountMethod::TYPESPECIFIC3){ + if (frame_count_method_ == FrameCountMethod::TYPESPECIFIC3) { devices_[index_on_opened_device].framecount_offset_ = gendc_descriptor_.getOffsetFromTypeSpecific(std::get<0>(data_comp_and_part), std::get<1>(data_comp_and_part), 3, 0); } int32_t image_component_index = gendc_descriptor_.getFirstComponentIndexByTypeID(ComponentIDIntensity); - if (image_component_index == -1){ + if (image_component_index == -1) { throw ::std::runtime_error("No available component found"); } ComponentHeader image_component = gendc_descriptor_.getComponentByIndex(image_component_index); group_id = gendc_descriptor_.getComponentByIndex(image_component_index).getGroupID(); } free(buffer); - }else{ + } else { devices_[index_on_opened_device].data_offset_ = 0; devices_[index_on_opened_device].image_payload_size_ = devices_[index_on_opened_device].u3v_payload_size_; log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "GenDC", "Not Supported"); } - // Set Device Info ================================================= { int32_t wi = arv_device_get_integer_feature_value(devices_[index_on_opened_device].device_, "Width", &err_); @@ -741,7 +734,7 @@ class U3V { throw std::runtime_error(err_->message); } double fps = 0.0; - if (arv_device_is_feature_available(devices_[index_on_opened_device].device_, "AcquisitionFrameRate", &err_)){ + if (arv_device_is_feature_available(devices_[index_on_opened_device].device_, "AcquisitionFrameRate", &err_)) { fps = arv_device_get_float_feature_value(devices_[index_on_opened_device].device_, "AcquisitionFrameRate", &err_); } if (err_) { @@ -750,35 +743,34 @@ class U3V { log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "Width", wi); log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "Height", hi); - int32_t px = arv_device_get_integer_feature_value (devices_[index_on_opened_device].device_, "PixelFormat", &err_); - if (err_ || px == 0){ + int32_t px = arv_device_get_integer_feature_value(devices_[index_on_opened_device].device_, "PixelFormat", &err_); + if (err_ || px == 0) { log::info("The pixel format is not supported for header info"); err_ = nullptr; px = 0; } - devices_[index_on_opened_device].header_info_ = { 1, wi, hi, - 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, - wi, hi, wi, hi, static_cast(fps), px, group_id - }; + devices_[index_on_opened_device].header_info_ = {1, wi, hi, + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + wi, hi, wi, hi, static_cast(fps), px, group_id}; } - if (index_on_opened_device == 0 && arv_device_is_feature_available(devices_[index_on_opened_device].device_, "OperationMode", &err_)){ + if (index_on_opened_device == 0 && arv_device_is_feature_available(devices_[index_on_opened_device].device_, "OperationMode", &err_)) { if (err_) { throw std::runtime_error(err_->message); } - const char* operation_mode_in_string; + const char *operation_mode_in_string; operation_mode_in_string = arv_device_get_string_feature_value(devices_[index_on_opened_device].device_, "OperationMode", &err_); if (err_) { throw std::runtime_error(err_->message); } - if (strcmp(operation_mode_in_string, "Came2USB1")==0){ + if (strcmp(operation_mode_in_string, "Came2USB1") == 0) { operation_mode_ = OperationMode::Came2USB1; - }else if (strcmp(operation_mode_in_string, "Came1USB1")==0){ + } else if (strcmp(operation_mode_in_string, "Came1USB1") == 0) { operation_mode_ = OperationMode::Came1USB1; - }else if (strcmp(operation_mode_in_string, "Came2USB2")==0){ + } else if (strcmp(operation_mode_in_string, "Came2USB2") == 0) { operation_mode_ = OperationMode::Came2USB2; - }else if (strcmp(operation_mode_in_string, "Came1USB2")==0){ + } else if (strcmp(operation_mode_in_string, "Came1USB2") == 0) { operation_mode_ = OperationMode::Came1USB2; num_usb_to_open += 1; devices_.resize(num_usb_to_open); @@ -787,7 +779,7 @@ class U3V { log::info("\tDevice/USB {}::{} : {}", index_on_opened_device, "OperationMode", operation_mode_in_string); } index_on_opened_device += 1; - }else{ + } else { log::info("\tDevice/USB {}::{} : {} ... skipped", index_on_opened_device, "device protocol", device_protocol); } @@ -795,19 +787,19 @@ class U3V { } } - void create_stream_and_start_acquisition(bool specific_device_to_flip_order){ + void create_stream_and_start_acquisition(bool specific_device_to_flip_order) { /* - * ion-kit starts the acquisition before stream creation This is a tentative fix only in ion-kit due to hardware issue - * In aravis, the acquisition should be done afterward. Since this maps better with GenAPI, where buffers - * must be pushed to DataStream objectsbefore DataStream acquisition is started. - * refer to https://github.com/AravisProject/aravis/blob/2ebaa8661761ea4bbc4df878aa67b4a9e1a9a3b9/docs/reference/aravis/porting-0.10.md - */ - if (specific_device_to_flip_order){ + * ion-kit starts the acquisition before stream creation This is a tentative fix only in ion-kit due to hardware issue + * In aravis, the acquisition should be done afterward. Since this maps better with GenAPI, where buffers + * must be pushed to DataStream objectsbefore DataStream acquisition is started. + * refer to https://github.com/AravisProject/aravis/blob/2ebaa8661761ea4bbc4df878aa67b4a9e1a9a3b9/docs/reference/aravis/porting-0.10.md + */ + if (specific_device_to_flip_order) { log::info("Execute AcquisitionStart before create stream on this device."); command_acquisition_mode_contd_and_start(); } - //start streaming after AcquisitionStart - for (auto i=0; imessage); @@ -817,27 +809,26 @@ class U3V { } } - if (! specific_device_to_flip_order){ + if (!specific_device_to_flip_order) { command_acquisition_mode_contd_and_start(); } } - void allocate_buffers(){ - for (auto i=0; i &bufs, int timeout_us){ + void sync_frame_count(std::vector &bufs, int timeout_us) { uint32_t max_cnt = 0; while (true) { // Update max_cnt @@ -868,11 +859,8 @@ class U3V { log::error("pop_buffer failed when sync frame due to timeout ({}s)", timeout_us * 1e-6f); throw ::std::runtime_error("buffer is null"); } - devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) - : -1; + devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) : + -1; i == 0 ? log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[i].frame_count_, "") : log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[i].frame_count_); @@ -881,7 +869,7 @@ class U3V { } } - void consume_old_buffer(std::vector &bufs, int timeout_us = 3 * 1000 * 1000){ + void consume_old_buffer(std::vector &bufs, int timeout_us = 3 * 1000 * 1000) { std::vector N_output_buffers(num_sensor_); for (auto i = 0; i < num_sensor_; ++i) { int32_t num_input_buffer; @@ -895,16 +883,12 @@ class U3V { log::error("pop_buffer(L2) failed due to timeout ({}s)", timeout_us * 1e-6f); throw ::std::runtime_error("buffer is null"); } - devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) - : -1; + devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) : + -1; i == 0 ? log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20}) [skipped for realtime display]", devices_[i].frame_count_, "") : log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20}) [skipped for realtime display]", "", devices_[i].frame_count_); arv_stream_push_buffer(devices_[i].stream_, bufs[i]); - } } } @@ -959,20 +943,20 @@ class U3V { arv_shutdown_t arv_shutdown; - arv_camera_new_t arv_camera_new; + arv_camera_new_t arv_camera_new; arv_camera_get_device_t arv_camera_get_device; arv_camera_create_stream_t arv_camera_create_stream; arv_fake_device_new_t arv_fake_device_new; arv_enable_interface_t arv_enable_interface; - arv_set_fake_camera_genicam_filename_t arv_set_fake_camera_genicam_filename; + arv_set_fake_camera_genicam_filename_t arv_set_fake_camera_genicam_filename; arv_fake_device_get_fake_camera_t arv_fake_device_get_fake_camera; arv_uv_device_set_usb_mode_t arv_uv_device_set_usb_mode; static std::map> instances_; - int32_t num_sensor_; //SENSOR NUMBER + int32_t num_sensor_; // SENSOR NUMBER DynamicModule gobject_; DynamicModule aravis_; @@ -985,7 +969,7 @@ class U3V { int32_t operation_mode_; uint32_t frame_cnt_; - int32_t device_idx_; //USB DEVICE INDEX + int32_t device_idx_; // USB DEVICE INDEX int frame_count_method_; // genDC @@ -993,84 +977,78 @@ class U3V { std::string pixel_format_; - std::vector devices_; //USB DEVICE + std::vector devices_; // USB DEVICE - std::vector > buffers_; + std::vector> buffers_; bool disposed_; bool sim_mode_; bool order_filp_; -}; // class U3V +}; // class U3V -std::map> U3V::instances_; +std::map> U3V::instances_; - -class U3VFakeCam : public U3V{ +class U3VFakeCam : public U3V { public: - static U3V & get_instance(const std::string& id, - int32_t num_sensor, - int32_t width = 640, - int32_t height = 480, - float_t fps = 25.0, - const std::string& pixel_format = "Mono8" - ) - { + static U3V &get_instance(const std::string &id, + int32_t num_sensor, + int32_t width = 640, + int32_t height = 480, + float_t fps = 25.0, + const std::string &pixel_format = "Mono8") { if (instances_.count(id) == 0) { ion::log::info("Create U3VFakeCam U3V instance: {}", id); - instances_[id] = std::unique_ptr(new U3VFakeCam(num_sensor, width, height, fps, pixel_format)); + instances_[id] = std::unique_ptr(new U3VFakeCam(num_sensor, width, height, fps, pixel_format)); } return *instances_[id].get(); } - void get(std::vector>& outs) override { + void get(std::vector> &outs) override { auto timeout_us = 30 * 1000 * 1000; std::vector bufs(num_sensor_); - for (int i = 0;i< num_sensor_;i++){ - auto size = devices_[i].u3v_payload_size_; - arv_stream_push_buffer (devices_[i].stream_, arv_buffer_new_allocate (size)); - bufs[i] = arv_stream_timeout_pop_buffer (devices_[i].stream_, timeout_us); - if (bufs[i] == nullptr){ - log::error("pop_buffer(L1) failed due to timeout ({}s)", timeout_us*1e-6f); - throw ::std::runtime_error("Buffer is null"); - } - devices_[i].frame_count_ += 1; - memcpy(outs[i].data(), arv_buffer_get_part_data(bufs[i], 0, nullptr), size); + for (int i = 0; i < num_sensor_; i++) { + auto size = devices_[i].u3v_payload_size_; + arv_stream_push_buffer(devices_[i].stream_, arv_buffer_new_allocate(size)); + bufs[i] = arv_stream_timeout_pop_buffer(devices_[i].stream_, timeout_us); + if (bufs[i] == nullptr) { + log::error("pop_buffer(L1) failed due to timeout ({}s)", timeout_us * 1e-6f); + throw ::std::runtime_error("Buffer is null"); + } + devices_[i].frame_count_ += 1; + memcpy(outs[i].data(), arv_buffer_get_part_data(bufs[i], 0, nullptr), size); } } private: - U3VFakeCam(int32_t num_sensor, int32_t width, int32_t height , float_t fps, const std::string & pixel_format, char* dev_id = nullptr) - : U3V(num_sensor, false, false, true, width, height , fps, pixel_format, nullptr){ + U3VFakeCam(int32_t num_sensor, int32_t width, int32_t height, float_t fps, const std::string &pixel_format, char *dev_id = nullptr) + : U3V(num_sensor, false, false, true, width, height, fps, pixel_format, nullptr) { open_fake_devices(width, height, fps, pixel_format); // Start streaming and start acquisition - for (auto i=0; i(new U3VRealCam(num_sensor, frame_sync, realtime_display_mode, sim_mode, width, height, fps, pixel_format)); @@ -1079,23 +1057,23 @@ class U3VRealCam: public U3V{ return *instances_[id].get(); } - void get(std::vector>& outs) override{ + void get(std::vector> &outs) override { auto timeout_us = 30 * 1000 * 1000; int32_t num_device = devices_.size(); - if (sim_mode_){ + if (sim_mode_) { std::vector bufs(num_device); - for (int i = 0;i< num_device;i++){ - auto size = devices_[i].u3v_payload_size_; - arv_stream_push_buffer (devices_[i].stream_, arv_buffer_new_allocate (size)); - bufs[i] = arv_stream_timeout_pop_buffer (devices_[i].stream_, timeout_us); - if (bufs[i] == nullptr){ - log::error("pop_buffer(L1) failed due to timeout ({}s)", timeout_us*1e-6f); - throw ::std::runtime_error("Buffer is null"); - } - devices_[i].frame_count_ += 1; - memcpy(outs[i].data(), arv_buffer_get_part_data(bufs[i], 0, nullptr), size); + for (int i = 0; i < num_device; i++) { + auto size = devices_[i].u3v_payload_size_; + arv_stream_push_buffer(devices_[i].stream_, arv_buffer_new_allocate(size)); + bufs[i] = arv_stream_timeout_pop_buffer(devices_[i].stream_, timeout_us); + if (bufs[i] == nullptr) { + log::error("pop_buffer(L1) failed due to timeout ({}s)", timeout_us * 1e-6f); + throw ::std::runtime_error("Buffer is null"); } - }else { + devices_[i].frame_count_ += 1; + memcpy(outs[i].data(), arv_buffer_get_part_data(bufs[i], 0, nullptr), size); + } + } else { std::vector bufs(num_device); @@ -1104,28 +1082,25 @@ class U3VRealCam: public U3V{ // if aravis output queue length is more than N (where N > 1) for all devices, pop all N-1 buffer if (realtime_display_mode_) { - consume_old_buffer(bufs,timeout_us); + consume_old_buffer(bufs, timeout_us); } // get the first buffer for each stream - for (auto i = 0; i (get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) - : -1; + devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) : + -1; i == 0 ? - log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[i].frame_count_, "") : - log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[i].frame_count_); + log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[i].frame_count_, "") : + log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[i].frame_count_); } if (frame_sync_) { - sync_frame_count(bufs,timeout_us); + sync_frame_count(bufs, timeout_us); } for (int i = 0; i < num_device; ++i) { @@ -1142,77 +1117,72 @@ class U3VRealCam: public U3V{ // if aravis output queue length is more than N (where N > 1) for all devices, pop all N-1 buffer if (realtime_display_mode_) { - consume_old_buffer(bufs,timeout_us); + consume_old_buffer(bufs, timeout_us); } - //first buffer + // first buffer device_idx_ = (device_idx_ + 1) >= num_device ? 0 : device_idx_ + 1; bufs[device_idx_] = arv_stream_timeout_pop_buffer(devices_[device_idx_].stream_, 30 * 1000 * 1000); if (bufs[device_idx_] == nullptr) { log::error("pop_buffer(L4) failed due to timeout ({}s)", timeout_us * 1e-6f); throw ::std::runtime_error("buffer is null"); } - devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) - : -1; - latest_cnt = devices_[device_idx_].frame_count_; - device_idx_ == 0 ? - log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[device_idx_].frame_count_, "") : - log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[device_idx_].frame_count_); - - int internal_count = 0; - int max_internal_count = 1000; - - while (frame_cnt_ >= latest_cnt) { - arv_stream_push_buffer(devices_[device_idx_].stream_, bufs[device_idx_]); - bufs[device_idx_] = arv_stream_timeout_pop_buffer (devices_[device_idx_].stream_, 30 * 1000 * 1000); - if (bufs[device_idx_] == nullptr){ - log::error("pop_buffer(L4) failed due to timeout ({}s)", timeout_us*1e-6f); - throw ::std::runtime_error("buffer is null"); - } - devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) - : -1; + devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) : + -1; latest_cnt = devices_[device_idx_].frame_count_; device_idx_ == 0 ? log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[device_idx_].frame_count_, "") : log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[device_idx_].frame_count_); - if (internal_count++ > max_internal_count){ - log::error("pop_buffer(L9) The sequential invalid buffer is more than {}; Stop the pipeline.", max_internal_count); - throw ::std::runtime_error("Invalid framecount"); + int internal_count = 0; + int max_internal_count = 1000; + + while (frame_cnt_ >= latest_cnt) { + arv_stream_push_buffer(devices_[device_idx_].stream_, bufs[device_idx_]); + bufs[device_idx_] = arv_stream_timeout_pop_buffer(devices_[device_idx_].stream_, 30 * 1000 * 1000); + if (bufs[device_idx_] == nullptr) { + log::error("pop_buffer(L4) failed due to timeout ({}s)", timeout_us * 1e-6f); + throw ::std::runtime_error("buffer is null"); + } + devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) : + -1; + latest_cnt = devices_[device_idx_].frame_count_; + device_idx_ == 0 ? + log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[device_idx_].frame_count_, "") : + log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[device_idx_].frame_count_); + + if (internal_count++ > max_internal_count) { + log::error("pop_buffer(L9) The sequential invalid buffer is more than {}; Stop the pipeline.", max_internal_count); + throw ::std::runtime_error("Invalid framecount"); + } } - } - frame_cnt_ = latest_cnt; - auto sz = (std::min)(devices_[device_idx_].image_payload_size_, static_cast(outs[0].size_in_bytes())); - ::memcpy(outs[0].data(), arv_buffer_get_part_data(bufs[device_idx_], 0, nullptr), sz); - arv_stream_push_buffer(devices_[device_idx_].stream_, bufs[device_idx_]); + frame_cnt_ = latest_cnt; + auto sz = (std::min)(devices_[device_idx_].image_payload_size_, static_cast(outs[0].size_in_bytes())); + ::memcpy(outs[0].data(), arv_buffer_get_part_data(bufs[device_idx_], 0, nullptr), sz); + arv_stream_push_buffer(devices_[device_idx_].stream_, bufs[device_idx_]); - log::trace("Obtained Frame from USB{}: {}", device_idx_, frame_cnt_); + log::trace("Obtained Frame from USB{}: {}", device_idx_, frame_cnt_); } } } + private: - U3VRealCam(int32_t num_sensor, bool frame_sync, bool realtime_display_mode, bool sim_mode, int32_t width, int32_t height , float_t fps, const std::string & pixel_format, char* dev_id = nullptr) - : U3V(num_sensor, frame_sync, realtime_display_mode, sim_mode, width, height , fps, pixel_format, nullptr){ + U3VRealCam(int32_t num_sensor, bool frame_sync, bool realtime_display_mode, bool sim_mode, int32_t width, int32_t height, float_t fps, const std::string &pixel_format, char *dev_id = nullptr) + : U3V(num_sensor, frame_sync, realtime_display_mode, sim_mode, width, height, fps, pixel_format, nullptr) { // check if the camera is available arv_update_device_list(); - auto num_device = arv_get_n_devices (); - if (num_device == 0){ + auto num_device = arv_get_n_devices(); + if (num_device == 0) { log::warn("Fallback to simulation mode: Could not find camera"); sim_mode_ = true; } - if (sim_mode_){ + if (sim_mode_) { open_fake_devices(width, height, fps, pixel_format); // Start streaming and start acquisition create_stream_and_start_acquisition(order_filp_); - }else{ + } else { // Real Camera validate_user_input(num_device, dev_id); open_real_devices(num_device, num_sensor_, dev_id); @@ -1222,20 +1192,17 @@ class U3VRealCam: public U3V{ }; }; - -class U3VGenDC: public U3V{ +class U3VGenDC : public U3V { public: - static U3V & get_instance(const std::string& id, - int32_t num_sensor, - bool frame_sync, - bool realtime_display_mode, - bool sim_mode = false, - int32_t width = 640, - int32_t height = 480, - float_t fps = 25.0, - const std::string& pixel_format = "Mono8" - ) - { + static U3V &get_instance(const std::string &id, + int32_t num_sensor, + bool frame_sync, + bool realtime_display_mode, + bool sim_mode = false, + int32_t width = 640, + int32_t height = 480, + float_t fps = 25.0, + const std::string &pixel_format = "Mono8") { if (instances_.count(id) == 0) { ion::log::info("Create GenDC instance: {}", id); instances_[id] = std::unique_ptr(new U3VGenDC(num_sensor, frame_sync, realtime_display_mode, sim_mode, width, height, fps, pixel_format)); @@ -1244,44 +1211,40 @@ class U3VGenDC: public U3V{ return *instances_[id].get(); } - - void get(std::vector& outs) override{ + void get(std::vector &outs) override { // TODO: Is 3 second fine? auto timeout_us = 3 * 1000 * 1000; int32_t num_device = devices_.size(); std::vector bufs(num_device); - if (sim_mode_){ + if (sim_mode_) { std::vector bufs(num_sensor_); - for (int i = 0;i< num_sensor_;i++){ - auto size = devices_[i].u3v_payload_size_; - arv_stream_push_buffer (devices_[i].stream_, arv_buffer_new_allocate (size)); - bufs[i] = arv_stream_timeout_pop_buffer (devices_[i].stream_, timeout_us); - if (bufs[i] == nullptr){ - log::error("pop_buffer(L1) failed due to timeout ({}s)", timeout_us*1e-6f); - throw ::std::runtime_error("Buffer is null"); - } - devices_[i].frame_count_ += 1; - memcpy(outs[i], arv_buffer_get_part_data(bufs[i], 0, nullptr), size);} - } - else if (operation_mode_ == OperationMode::Came2USB2 || operation_mode_ == OperationMode::Came1USB1){ + for (int i = 0; i < num_sensor_; i++) { + auto size = devices_[i].u3v_payload_size_; + arv_stream_push_buffer(devices_[i].stream_, arv_buffer_new_allocate(size)); + bufs[i] = arv_stream_timeout_pop_buffer(devices_[i].stream_, timeout_us); + if (bufs[i] == nullptr) { + log::error("pop_buffer(L1) failed due to timeout ({}s)", timeout_us * 1e-6f); + throw ::std::runtime_error("Buffer is null"); + } + devices_[i].frame_count_ += 1; + memcpy(outs[i], arv_buffer_get_part_data(bufs[i], 0, nullptr), size); + } + } else if (operation_mode_ == OperationMode::Came2USB2 || operation_mode_ == OperationMode::Came1USB1) { - if (realtime_display_mode_){ - consume_old_buffer(bufs,timeout_us); + if (realtime_display_mode_) { + consume_old_buffer(bufs, timeout_us); } // get the first buffer for each stream - for (auto i = 0; i< devices_.size(); ++i) { - bufs[i] = arv_stream_timeout_pop_buffer (devices_[i].stream_, timeout_us); - if (bufs[i] == nullptr){ - log::error("pop_buffer(L5) failed due to timeout ({}s)", timeout_us*1e-6f); + for (auto i = 0; i < devices_.size(); ++i) { + bufs[i] = arv_stream_timeout_pop_buffer(devices_[i].stream_, timeout_us); + if (bufs[i] == nullptr) { + log::error("pop_buffer(L5) failed due to timeout ({}s)", timeout_us * 1e-6f); throw ::std::runtime_error("buffer is null"); } - devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) - : -1; + devices_[i].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[i], devices_[i])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[i]) & 0x00000000FFFFFFFF) : + -1; i == 0 ? log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[i].frame_count_, "") : @@ -1289,10 +1252,10 @@ class U3VGenDC: public U3V{ } if (frame_sync_) { - sync_frame_count(bufs,timeout_us); + sync_frame_count(bufs, timeout_us); } - for (int i = 0; i < num_sensor_; ++i){ + for (int i = 0; i < num_sensor_; ++i) { ::memcpy(outs[i], arv_buffer_get_data(bufs[i], nullptr), devices_[i].u3v_payload_size_); // ::memcpy(outs[i*num_sensor_+1], &(devices_[i].header_info_), sizeof(ion::bb::image_io::rawHeader)); arv_stream_push_buffer(devices_[i].stream_, bufs[i]); @@ -1303,22 +1266,19 @@ class U3VGenDC: public U3V{ int32_t min_frame_device_idx = 0; // if aravis output queue length is more than N (where N > 1) for all devices, pop all N-1 buffer - if (realtime_display_mode_){ - consume_old_buffer(bufs,timeout_us); + if (realtime_display_mode_) { + consume_old_buffer(bufs, timeout_us); } - //first buffer - device_idx_ = (device_idx_+1) >= num_device ? 0 : device_idx_+1; - bufs[device_idx_] = arv_stream_timeout_pop_buffer (devices_[device_idx_].stream_, 30 * 1000 * 1000); - if (bufs[device_idx_] == nullptr){ - log::error("pop_buffer(L4) failed due to timeout ({}s)", timeout_us*1e-6f); + // first buffer + device_idx_ = (device_idx_ + 1) >= num_device ? 0 : device_idx_ + 1; + bufs[device_idx_] = arv_stream_timeout_pop_buffer(devices_[device_idx_].stream_, 30 * 1000 * 1000); + if (bufs[device_idx_] == nullptr) { + log::error("pop_buffer(L4) failed due to timeout ({}s)", timeout_us * 1e-6f); throw ::std::runtime_error("buffer is null"); } - devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) - : -1; + devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) : + -1; latest_cnt = devices_[device_idx_].frame_count_; device_idx_ == 0 ? log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[device_idx_].frame_count_, "") : @@ -1327,28 +1287,25 @@ class U3VGenDC: public U3V{ int internal_count = 0; int max_internal_count = 1000; - while (frame_cnt_ >= latest_cnt) { - arv_stream_push_buffer(devices_[device_idx_].stream_, bufs[device_idx_]); - auto timeout2_us = 30 * 1000 * 1000; - bufs[device_idx_] = arv_stream_timeout_pop_buffer (devices_[device_idx_].stream_, timeout2_us); - if (bufs[device_idx_] == nullptr){ - log::error("pop_buffer(L8) failed due to timeout ({}s)", timeout2_us*1e-6f); - throw ::std::runtime_error("buffer is null"); - } - devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 - ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) - : frame_count_method_ == FrameCountMethod::TIMESTAMP - ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) - : -1; - device_idx_ == 0 ? - log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[device_idx_].frame_count_, "") : - log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[device_idx_].frame_count_); - latest_cnt = devices_[device_idx_].frame_count_; - if (internal_count++ > max_internal_count){ - log::error("pop_buffer(L10) The sequential invalid buffer is more than {}; Stop the pipeline.", max_internal_count); - throw ::std::runtime_error("Invalid framecount"); - } + while (frame_cnt_ >= latest_cnt) { + arv_stream_push_buffer(devices_[device_idx_].stream_, bufs[device_idx_]); + auto timeout2_us = 30 * 1000 * 1000; + bufs[device_idx_] = arv_stream_timeout_pop_buffer(devices_[device_idx_].stream_, timeout2_us); + if (bufs[device_idx_] == nullptr) { + log::error("pop_buffer(L8) failed due to timeout ({}s)", timeout2_us * 1e-6f); + throw ::std::runtime_error("buffer is null"); + } + devices_[device_idx_].frame_count_ = frame_count_method_ == FrameCountMethod::TYPESPECIFIC3 ? static_cast(get_frame_count_from_genDC_descriptor(bufs[device_idx_], devices_[device_idx_])) : frame_count_method_ == FrameCountMethod::TIMESTAMP ? static_cast(arv_buffer_get_timestamp(bufs[device_idx_]) & 0x00000000FFFFFFFF) : + -1; + device_idx_ == 0 ? + log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", devices_[device_idx_].frame_count_, "") : + log::trace("All-Popped Frames (USB0, USB1)=({:20}, {:20})", "", devices_[device_idx_].frame_count_); + latest_cnt = devices_[device_idx_].frame_count_; + if (internal_count++ > max_internal_count) { + log::error("pop_buffer(L10) The sequential invalid buffer is more than {}; Stop the pipeline.", max_internal_count); + throw ::std::runtime_error("Invalid framecount"); } + } frame_cnt_ = latest_cnt; ::memcpy(outs[0], arv_buffer_get_data(bufs[device_idx_], nullptr), devices_[device_idx_].u3v_payload_size_); @@ -1357,22 +1314,23 @@ class U3VGenDC: public U3V{ log::trace("Obtained Frame from USB{}: {}", device_idx_, frame_cnt_); } } + private: - U3VGenDC(int32_t num_sensor, bool frame_sync, bool realtime_display_mode, bool sim_mode, int32_t width, int32_t height , float_t fps, const std::string & pixel_format, char* dev_id = nullptr) - : U3V(num_sensor, frame_sync, realtime_display_mode, sim_mode, width, height , fps, pixel_format, nullptr){ + U3VGenDC(int32_t num_sensor, bool frame_sync, bool realtime_display_mode, bool sim_mode, int32_t width, int32_t height, float_t fps, const std::string &pixel_format, char *dev_id = nullptr) + : U3V(num_sensor, frame_sync, realtime_display_mode, sim_mode, width, height, fps, pixel_format, nullptr) { // check if the camera is available arv_update_device_list(); - auto num_device = arv_get_n_devices (); - if (num_device == 0){ + auto num_device = arv_get_n_devices(); + if (num_device == 0) { log::warn("Fallback to simulation mode: Could not find camera"); sim_mode_ = true; } - if (sim_mode_){ + if (sim_mode_) { open_fake_devices(width, height, fps, pixel_format); // Start streaming and start acquisition create_stream_and_start_acquisition(order_filp_); - }else{ + } else { // Real Camera validate_user_input(num_device, dev_id); open_real_devices(num_device, num_sensor_, dev_id); @@ -1380,23 +1338,20 @@ class U3VGenDC: public U3V{ allocate_buffers(); } }; - }; } // namespace image_io } // namespace bb } // namespace ion -extern "C" -int ION_EXPORT u3v_dispose(const char *id) { +extern "C" int ION_EXPORT u3v_dispose(const char *id) { ion::bb::image_io::U3V::release_instance(id); return 0; } int u3v_camera_frame_count( - const std::string& id, int32_t num_sensor, bool frame_sync, bool realtime_display_mode, - halide_buffer_t* out) -{ + const std::string &id, int32_t num_sensor, bool frame_sync, bool realtime_display_mode, + halide_buffer_t *out) { try { auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, num_sensor, frame_sync, realtime_display_mode)); std::vector obufs{out->host}; @@ -1404,8 +1359,7 @@ int u3v_camera_frame_count( out->dim[0].min = 0; out->dim[0].extent = num_sensor; return 0; - } - else { + } else { u3v.get_frame_count(obufs); } @@ -1420,28 +1374,26 @@ int u3v_camera_frame_count( } } -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_camera1( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_camera1( bool frame_sync, bool realtime_display_mode, double gain0, double exposure0, - halide_buffer_t * id_buf, halide_buffer_t * gain_key_buf, halide_buffer_t * exposure_key_buf, - halide_buffer_t * out0) -{ + halide_buffer_t *id_buf, halide_buffer_t *gain_key_buf, halide_buffer_t *exposure_key_buf, + halide_buffer_t *out0) { using namespace Halide; try { const std::string id(reinterpret_cast(id_buf->host)); - const std::string gain_key(reinterpret_cast(gain_key_buf->host)); - const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); + const std::string gain_key(reinterpret_cast(gain_key_buf->host)); + const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 1, frame_sync, realtime_display_mode)); if (out0->is_bounds_query()) { - //bounds query + // bounds query return 0; - }else{ + } else { // set gain & exposure u3v.set_gain(0, gain_key, gain0); u3v.set_exposure(0, exposure_key, exposure0); - std::vector > obufs{Halide::Buffer<>(*out0)}; + std::vector> obufs{Halide::Buffer<>(*out0)}; u3v.get(obufs); } @@ -1456,29 +1408,27 @@ int ION_EXPORT ion_bb_image_io_u3v_camera1( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_camera1); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_camera2( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_camera2( bool frame_sync, bool realtime_display_mode, double gain0, double gain1, double exposure0, double exposure1, - halide_buffer_t * id_buf, halide_buffer_t * gain_key_buf, halide_buffer_t * exposure_key_buf, - halide_buffer_t * out0, halide_buffer_t * out1) -{ + halide_buffer_t *id_buf, halide_buffer_t *gain_key_buf, halide_buffer_t *exposure_key_buf, + halide_buffer_t *out0, halide_buffer_t *out1) { using namespace Halide; try { const std::string id(reinterpret_cast(id_buf->host)); - const std::string gain_key(reinterpret_cast(gain_key_buf->host)); - const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); + const std::string gain_key(reinterpret_cast(gain_key_buf->host)); + const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 2, frame_sync, realtime_display_mode)); if (out0->is_bounds_query() || out1->is_bounds_query()) { - //bounds query + // bounds query return 0; - }else{ + } else { // set gain & exposure u3v.set_gain(0, gain_key, gain0); u3v.set_gain(1, gain_key, gain1); u3v.set_exposure(0, exposure_key, exposure0); u3v.set_exposure(1, exposure_key, exposure1); - std::vector > obufs{Halide::Buffer<>(*out0), Halide::Buffer<>(*out1)}; + std::vector> obufs{Halide::Buffer<>(*out0), Halide::Buffer<>(*out1)}; u3v.get(obufs); } return 0; @@ -1492,54 +1442,47 @@ int ION_EXPORT ion_bb_image_io_u3v_camera2( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_camera2); - -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_camera1_frame_count( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_camera1_frame_count( halide_buffer_t *, int32_t num_sensor, bool frame_sync, bool realtime_display_mode, - halide_buffer_t * id_buf, halide_buffer_t* out) -{ + halide_buffer_t *id_buf, halide_buffer_t *out) { const std::string id(reinterpret_cast(id_buf->host)); return u3v_camera_frame_count(id, num_sensor, frame_sync, realtime_display_mode, out); } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_camera1_frame_count); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_camera2_frame_count( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_camera2_frame_count( halide_buffer_t *, halide_buffer_t *, int32_t num_sensor, bool frame_sync, bool realtime_display_mode, - halide_buffer_t * id_buf, halide_buffer_t* out) -{ const std::string id(reinterpret_cast(id_buf->host)); + halide_buffer_t *id_buf, halide_buffer_t *out) { + const std::string id(reinterpret_cast(id_buf->host)); return u3v_camera_frame_count(id, num_sensor, frame_sync, realtime_display_mode, out); } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_camera2_frame_count); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_gendc_camera1( - halide_buffer_t * id_buf, +extern "C" int ION_EXPORT ion_bb_image_io_u3v_gendc_camera1( + halide_buffer_t *id_buf, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, bool enable_control, - halide_buffer_t * gain_key_buf, halide_buffer_t * exposure_key_buf, halide_buffer_t * pixel_format_buf, + halide_buffer_t *gain_key_buf, halide_buffer_t *exposure_key_buf, halide_buffer_t *pixel_format_buf, double gain0, double exposure0, - halide_buffer_t * out_gendc - ) -{ + halide_buffer_t *out_gendc) { using namespace Halide; int num_output = 1; try { const std::string id(reinterpret_cast(id_buf->host)); - const std::string gain_key(reinterpret_cast(gain_key_buf->host)); - const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); + const std::string gain_key(reinterpret_cast(gain_key_buf->host)); + const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); const std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); if (out_gendc->is_bounds_query()) { return 0; } auto &u3v(ion::bb::image_io::U3VGenDC::get_instance(id, 1, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); - // set gain & exposure - if (enable_control){ + // set gain & exposure + if (enable_control) { ion::log::debug("Setting gain0:{} exposure0:{}", gain0, exposure0); u3v.set_gain(0, gain_key, gain0); u3v.set_exposure(0, exposure_key, exposure0); @@ -1558,32 +1501,29 @@ int ION_EXPORT ion_bb_image_io_u3v_gendc_camera1( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_gendc_camera1); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_gendc_camera2( - halide_buffer_t * id_buf, +extern "C" int ION_EXPORT ion_bb_image_io_u3v_gendc_camera2( + halide_buffer_t *id_buf, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, bool enable_control, - halide_buffer_t * gain_key_buf, halide_buffer_t * exposure_key_buf, halide_buffer_t * pixel_format_buf, + halide_buffer_t *gain_key_buf, halide_buffer_t *exposure_key_buf, halide_buffer_t *pixel_format_buf, double gain0, double exposure0, double gain1, double exposure1, - halide_buffer_t * out_gendc0, halide_buffer_t * out_gendc1 - ) -{ + halide_buffer_t *out_gendc0, halide_buffer_t *out_gendc1) { using namespace Halide; try { const std::string id(reinterpret_cast(id_buf->host)); - const std::string gain_key(reinterpret_cast(gain_key_buf->host)); - const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); + const std::string gain_key(reinterpret_cast(gain_key_buf->host)); + const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); const std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); - if (out_gendc0->is_bounds_query() || out_gendc1->is_bounds_query() ) { + if (out_gendc0->is_bounds_query() || out_gendc1->is_bounds_query()) { return 0; } std::vector obufs{out_gendc0->host, out_gendc1->host}; auto &u3v(ion::bb::image_io::U3VGenDC::get_instance(id, 2, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); - // set gain & exposure + // set gain & exposure if (enable_control) { ion::log::debug("Setting gain0:{} exposure0:{}", gain0, exposure0); u3v.set_gain(0, gain_key, gain0); @@ -1593,7 +1533,7 @@ int ION_EXPORT ion_bb_image_io_u3v_gendc_camera2( u3v.set_gain(1, gain_key, gain1); u3v.set_exposure(1, exposure_key, exposure1); } - u3v.get(obufs); + u3v.get(obufs); return 0; } catch (const std::exception &e) { @@ -1606,40 +1546,38 @@ int ION_EXPORT ion_bb_image_io_u3v_gendc_camera2( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_gendc_camera2); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_multiple_camera1( - halide_buffer_t * id_buf, +extern "C" int ION_EXPORT ion_bb_image_io_u3v_multiple_camera1( + halide_buffer_t *id_buf, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, bool enable_control, - halide_buffer_t * gain_key_buf, halide_buffer_t * exposure_key_buf, halide_buffer_t * pixel_format_buf, + halide_buffer_t *gain_key_buf, halide_buffer_t *exposure_key_buf, halide_buffer_t *pixel_format_buf, double gain0, double exposure0, - halide_buffer_t * out0) -{ + halide_buffer_t *out0) { using namespace Halide; int num_output = 1; try { const std::string id(reinterpret_cast(id_buf->host)); - const std::string gain_key(reinterpret_cast(gain_key_buf->host)); - const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); + const std::string gain_key(reinterpret_cast(gain_key_buf->host)); + const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); const std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); std::vector> obufs{Halide::Buffer<>(*out0)}; if (out0->is_bounds_query()) { return 0; } - if(force_sim_mode){ - auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, num_output, width, height, fps, pixel_format)); - u3v.get(obufs); - }else{ - auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, num_output, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); - if (enable_control) { - // set gain & exposure + if (force_sim_mode) { + auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, num_output, width, height, fps, pixel_format)); + u3v.get(obufs); + } else { + auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, num_output, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); + if (enable_control) { + // set gain & exposure ion::log::debug("Setting gain0:{} exposure0:{}", gain0, exposure0); u3v.set_gain(0, gain_key, gain0); u3v.set_exposure(0, exposure_key, exposure0); - } - u3v.get(obufs); + } + u3v.get(obufs); } return 0; } catch (const std::exception &e) { @@ -1652,35 +1590,33 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera1( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_multiple_camera1); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_multiple_camera2( - halide_buffer_t * id_buf, +extern "C" int ION_EXPORT ion_bb_image_io_u3v_multiple_camera2( + halide_buffer_t *id_buf, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, bool enable_control, - halide_buffer_t * gain_key_buf, halide_buffer_t * exposure_key_buf, halide_buffer_t * pixel_format_buf, + halide_buffer_t *gain_key_buf, halide_buffer_t *exposure_key_buf, halide_buffer_t *pixel_format_buf, double gain0, double exposure0, double gain1, double exposure1, - halide_buffer_t * out0, halide_buffer_t * out1) -{ + halide_buffer_t *out0, halide_buffer_t *out1) { using namespace Halide; int num_output = 2; try { const std::string id(reinterpret_cast(id_buf->host)); - const std::string gain_key(reinterpret_cast(gain_key_buf->host)); - const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); + const std::string gain_key(reinterpret_cast(gain_key_buf->host)); + const std::string exposure_key(reinterpret_cast(exposure_key_buf->host)); std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); std::vector> obufs{Halide::Buffer<>(*out0), Halide::Buffer<>(*out1)}; if (out0->is_bounds_query() || out1->is_bounds_query()) { return 0; } - if(force_sim_mode){ - auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, num_output, width, height, fps, pixel_format)); - u3v.get(obufs); - }else{ - auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, num_output, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); - if (enable_control) { - // set gain & exposure + if (force_sim_mode) { + auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, num_output, width, height, fps, pixel_format)); + u3v.get(obufs); + } else { + auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, num_output, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); + if (enable_control) { + // set gain & exposure ion::log::debug("Setting gain0:{} exposure0:{}", gain0, exposure0); u3v.set_gain(0, gain_key, gain0); u3v.set_exposure(0, exposure_key, exposure0); @@ -1688,7 +1624,7 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera2( u3v.set_gain(1, gain_key, gain1); u3v.set_exposure(1, exposure_key, exposure1); } - u3v.get(obufs); + u3v.get(obufs); } return 0; } catch (const std::exception &e) { @@ -1701,16 +1637,14 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera2( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_multiple_camera2); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count1( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count1( halide_buffer_t *, - halide_buffer_t * id_buf, int32_t num_sensor, + halide_buffer_t *id_buf, int32_t num_sensor, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, - halide_buffer_t * pixel_format_buf, - halide_buffer_t* out) -{ + halide_buffer_t *pixel_format_buf, + halide_buffer_t *out) { try { const std::string id(reinterpret_cast(id_buf->host)); @@ -1721,12 +1655,12 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count1( out->dim[0].extent = num_sensor; return 0; } - if(force_sim_mode){ - auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, 1, width, height, fps, pixel_format)); - u3v.get_frame_count(obufs); - }else{ - auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 1, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); - u3v.get_frame_count(obufs); + if (force_sim_mode) { + auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, 1, width, height, fps, pixel_format)); + u3v.get_frame_count(obufs); + } else { + auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 1, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); + u3v.get_frame_count(obufs); } return 0; @@ -1740,17 +1674,15 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count1( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_multiple_camera_frame_count1); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count2( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count2( halide_buffer_t *, halide_buffer_t *, - halide_buffer_t * id_buf, int32_t num_sensor, + halide_buffer_t *id_buf, int32_t num_sensor, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, - halide_buffer_t * pixel_format_buf, - halide_buffer_t * out0, halide_buffer_t * out1) -{ + halide_buffer_t *pixel_format_buf, + halide_buffer_t *out0, halide_buffer_t *out1) { try { const std::string id(reinterpret_cast(id_buf->host)); const std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); @@ -1758,10 +1690,10 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count2( if (out0->is_bounds_query() || out1->is_bounds_query()) { return 0; } - if(force_sim_mode){ + if (force_sim_mode) { auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, 2, width, height, fps, pixel_format)); u3v.get_frame_count(obufs); - }else{ + } else { auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 2, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); u3v.get_frame_count(obufs); } @@ -1773,38 +1705,33 @@ int ION_EXPORT ion_bb_image_io_u3v_multiple_camera_frame_count2( ion::log::error("Unknown exception was thrown"); return 1; } - } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_multiple_camera_frame_count2); - -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_device_info1( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_device_info1( halide_buffer_t *, - halide_buffer_t * id_buf, int32_t num_sensor, + halide_buffer_t *id_buf, int32_t num_sensor, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, - halide_buffer_t * pixel_format_buf, - halide_buffer_t * out_deviceinfo - ) -{ + halide_buffer_t *pixel_format_buf, + halide_buffer_t *out_deviceinfo) { using namespace Halide; int num_output = 1; try { const std::string id(reinterpret_cast(id_buf->host)); const std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); - if (out_deviceinfo->is_bounds_query()){ + if (out_deviceinfo->is_bounds_query()) { out_deviceinfo->dim[0].min = 0; out_deviceinfo->dim[0].extent = sizeof(ion::bb::image_io::rawHeader); return 0; } std::vector obufs{out_deviceinfo->host}; - if(force_sim_mode){ + if (force_sim_mode) { auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, 1, width, height, fps, pixel_format)); u3v.get_device_info(obufs); - }else{ + } else { auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 1, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); u3v.get_device_info(obufs); } @@ -1820,17 +1747,14 @@ int ION_EXPORT ion_bb_image_io_u3v_device_info1( } ION_REGISTER_EXTERN(ion_bb_image_io_u3v_device_info1); -extern "C" -int ION_EXPORT ion_bb_image_io_u3v_device_info2( +extern "C" int ION_EXPORT ion_bb_image_io_u3v_device_info2( halide_buffer_t *, halide_buffer_t *, - halide_buffer_t * id_buf, int32_t num_sensor, + halide_buffer_t *id_buf, int32_t num_sensor, bool force_sim_mode, int32_t width, int32_t height, float_t fps, bool frame_sync, bool realtime_display_mode, - halide_buffer_t * pixel_format_buf, - halide_buffer_t * deviceinfo0, halide_buffer_t * deviceinfo1 - ) -{ + halide_buffer_t *pixel_format_buf, + halide_buffer_t *deviceinfo0, halide_buffer_t *deviceinfo1) { using namespace Halide; try { @@ -1839,21 +1763,21 @@ int ION_EXPORT ion_bb_image_io_u3v_device_info2( const std::string pixel_format(reinterpret_cast(pixel_format_buf->host)); if (deviceinfo0->is_bounds_query() || deviceinfo1->is_bounds_query()) { - if (deviceinfo0->is_bounds_query()){ + if (deviceinfo0->is_bounds_query()) { deviceinfo0->dim[0].min = 0; deviceinfo0->dim[0].extent = sizeof(ion::bb::image_io::rawHeader); } - if (deviceinfo1->is_bounds_query()){ + if (deviceinfo1->is_bounds_query()) { deviceinfo1->dim[0].min = 0; deviceinfo1->dim[0].extent = sizeof(ion::bb::image_io::rawHeader); } return 0; } std::vector obufs{deviceinfo0->host, deviceinfo1->host}; - if(force_sim_mode){ + if (force_sim_mode) { auto &u3v(ion::bb::image_io::U3VFakeCam::get_instance(id, 2, width, height, fps, pixel_format)); u3v.get_device_info(obufs); - }else{ + } else { auto &u3v(ion::bb::image_io::U3VRealCam::get_instance(id, 2, frame_sync, realtime_display_mode, force_sim_mode, width, height, fps, pixel_format)); u3v.get_device_info(obufs); } diff --git a/src/bb/image-io/rt_v4l2.h b/src/bb/image-io/rt_v4l2.h index 34116359..7db5a81f 100644 --- a/src/bb/image-io/rt_v4l2.h +++ b/src/bb/image-io/rt_v4l2.h @@ -31,44 +31,42 @@ namespace ion { namespace bb { namespace image_io { #define BORDER_INTERPOLATE(x, l) (x < 0 ? 0 : (x >= l ? l - 1 : x)) -float weight(float input){ - float alpha = -1; - float x = (input < 0)? -input : input; - float x2 = x * x; - float x3 = x * x * x; - - if(x <= 1){ - return (alpha + 2) * x3 - (alpha + 3) * x2 + 1; - }else if(x < 2){ - return alpha * x3 - 5 * alpha * x2 + 8 * alpha * x - 4 * alpha; - }else{ - return 0x0; - } - } +float weight(float input) { + float alpha = -1; + float x = (input < 0) ? -input : input; + float x2 = x * x; + float x3 = x * x * x; + + if (x <= 1) { + return (alpha + 2) * x3 - (alpha + 3) * x2 + 1; + } else if (x < 2) { + return alpha * x3 - 5 * alpha * x2 + 8 * alpha * x - 4 * alpha; + } else { + return 0x0; + } +} template -void resize_bicubic(Halide::Runtime::Buffer& dst, - const Halide::Runtime::Buffer& src, - const int32_t src_width, const int32_t src_height, - const uint32_t dst_width, const uint32_t dst_height){ +void resize_bicubic(Halide::Runtime::Buffer &dst, + const Halide::Runtime::Buffer &src, + const int32_t src_width, const int32_t src_height, + const uint32_t dst_width, const uint32_t dst_height) { double min_value = static_cast(std::numeric_limits::min()); double max_value = static_cast(std::numeric_limits::max()); - for(int c = 0; c < 3; c++){ - for(int dh = 0; dh < dst_height; dh++){ - for(int dw = 0; dw < dst_width; dw++){ + for (int c = 0; c < 3; c++) { + for (int dh = 0; dh < dst_height; dh++) { + for (int dw = 0; dw < dst_width; dw++) { double value = 0; float totalWeight = 0; - float x = ((static_cast(dw)+ 0.5f) - *static_cast(src_width)) / static_cast(dst_width); + float x = ((static_cast(dw) + 0.5f) * static_cast(src_width)) / static_cast(dst_width); x -= 0.5f; - float y = (static_cast(dh)+ 0.5f) - *static_cast(src_height) / static_cast(dst_height); + float y = (static_cast(dh) + 0.5f) * static_cast(src_height) / static_cast(dst_height); y -= 0.5f; float dx = x - static_cast(floor(x)); float dy = y - static_cast(floor(y)); - for(int i = -1; i < 3; i++){ - for(int j = -1; j < 3; j++){ + for (int i = -1; i < 3; i++) { + for (int j = -1; j < 3; j++) { float wx = weight(j - dx); float wy = weight(i - dy); @@ -78,26 +76,24 @@ void resize_bicubic(Halide::Runtime::Buffer& dst, int sh = BORDER_INTERPOLATE((int)(y + i), src_height); T s = src(sw, sh, c); - value += w*s; + value += w * s; totalWeight += w; } - } - if(fabs(totalWeight)>0){ + if (fabs(totalWeight) > 0) { value /= fabs(totalWeight); - }else{ - value= 0; + } else { + value = 0; } value += 0.5; value = (value < min_value) ? min_value : value; value = (value > max_value) ? max_value : value; dst(dw, dh, c) = static_cast(value); - } + } } } } - std::unordered_map> image_cache; template @@ -112,13 +108,13 @@ bool get_image(const std::string &url, Halide::Runtime::Buffer &img, int widt std::tie(host_name, path_name) = parse_url(url); Halide::Runtime::Buffer img_buf; - if (host_name.empty() || path_name.empty()){ + if (host_name.empty() || path_name.empty()) { // fallback to local file - if (std::filesystem::exists(url)){ + if (std::filesystem::exists(url)) { img_buf = Halide::Tools::load_and_convert_image(url); img_loaded = true; } - }else{ + } else { httplib::Client cli(host_name.c_str()); cli.set_follow_location(true); auto res = cli.Get(path_name.c_str()); @@ -126,34 +122,34 @@ bool get_image(const std::string &url, Halide::Runtime::Buffer &img, int widt std::vector data(res->body.size()); data.resize(res->body.size()); std::memcpy(data.data(), res->body.c_str(), res->body.size()); - std::filesystem::path dir_path = std::filesystem::temp_directory_path() / "simulation_camera";; + std::filesystem::path dir_path = std::filesystem::temp_directory_path() / "simulation_camera"; + ; if (!std::filesystem::exists(dir_path)) { if (!std::filesystem::create_directory(dir_path)) { throw std::runtime_error("Failed to create temporary directory"); } } std::ofstream ofs(dir_path / std::filesystem::path(url).filename(), std::ios::binary); - ofs.write(reinterpret_cast(data.data()), data.size()); + ofs.write(reinterpret_cast(data.data()), data.size()); - img_buf = Halide::Tools::load_and_convert_image(dir_path / std::filesystem::path(url).filename()); + img_buf = Halide::Tools::load_and_convert_image(dir_path / std::filesystem::path(url).filename()); img_loaded = true; - } } - if (img_loaded){ //resize + if (img_loaded) { // resize int ori_width = img_buf.width(); int ori_height = img_buf.height(); - int channels = img_buf.channels(); - Halide::Runtime::Buffer resized (width_, height_, 3); + int channels = img_buf.channels(); + Halide::Runtime::Buffer resized(width_, height_, 3); resize_bicubic(resized, img_buf, ori_width, ori_height, width_, height_); - if (sizeof(T) == 4){ //float + if (sizeof(T) == 4) { // float // Buffer to Buffer range(0-1) Halide::Runtime::Buffer float_img = Halide::Tools::ImageTypeConversion::convert_image(resized, halide_type_of()); img.copy_from(float_img); - }else if (sizeof(T) == 1){ //uint_8 + } else if (sizeof(T) == 1) { // uint_8 img.copy_from(resized); - }else{ + } else { throw std::runtime_error("Unsupported image format"); } } @@ -184,10 +180,9 @@ class V4L2 { }; public: - static V4L2 &get_instance(int32_t id, int32_t index, int32_t fps, int32_t width, int32_t height, uint32_t pixel_format, float gain_r, float gain_g, float gain_b, float offset, int32_t bit_width, int32_t bit_shift, - bool force_sim_mode, const std::string& url) { + bool force_sim_mode, const std::string &url) { if (instances_.count(id) == 0) { instances_[id] = std::make_shared(id, index, fps, width, height, pixel_format, gain_r, gain_g, gain_b, offset, bit_width, bit_shift, force_sim_mode, url); } @@ -196,7 +191,7 @@ class V4L2 { V4L2(int32_t id, int32_t index, int32_t fps, int32_t width, int32_t height, uint32_t pixel_format, float gain_r, float gain_g, float gain_b, float offset, int32_t bit_width, int32_t bit_shift, - bool force_sim_mode, const std::string& url) + bool force_sim_mode, const std::string &url) : id_(id), index_(index), fps_(fps), width_(width), height_(height), pixel_format_(pixel_format), gain_r_(gain_r), gain_g_(gain_g), gain_b_(gain_b), offset_(offset), bit_width_(bit_width), bit_shift_(bit_shift), sim_mode_(force_sim_mode), url_(url) { @@ -211,19 +206,22 @@ class V4L2 { struct stat st; if (-1 == stat(dev_name, &st)) { log::warn("Fallback to simulation mode: Could not find {}", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } if (!S_ISCHR(st.st_mode)) { log::warn("Fallback to simulation mode: {} is not proper device", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } fd_ = open(dev_name, O_RDWR | O_NONBLOCK, 0); if (-1 == fd_) { log::warn("Fallback to simulation mode: Cannot open {}: {}, {}", dev_name, errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } @@ -231,22 +229,26 @@ class V4L2 { if (-1 == xioctl(fd_, VIDIOC_QUERYCAP, &cap)) { if (EINVAL == errno) { log::warn("Fallback to simulation mode: {} is not V4L2 device", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } else { log::warn("Fallback to simulation mode: {} error {}, {}", "VIDIOC_QUERYCAP", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } } if (!(cap.capabilities & V4L2_CAP_VIDEO_CAPTURE)) { log::warn("Fallback to simulation mode: {} is not video capture device", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } if (!(cap.capabilities & V4L2_CAP_STREAMING)) { log::warn("Fallback to simulation mode: {} s does not support streaming i/o", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } @@ -265,7 +267,8 @@ class V4L2 { } if (!supported) { log::warn("Fallback to simulation mode: {} does not support desired pixel format", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } @@ -281,30 +284,33 @@ class V4L2 { }; if (-1 == xioctl(fd_, VIDIOC_S_FMT, &fmt)) { log::warn("Fallback to simulation mode: {} error {}, {}", "VIDIOC_S_FMT", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } if (width != fmt.fmt.pix.width || height != fmt.fmt.pix.height) { log::warn("Fallback to simulation mode: {} does not support desired resolution, expected({}x{}), actual({}x{})", dev_name, fmt.fmt.pix.width, fmt.fmt.pix.height, width, height); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } buffer_size_ = fmt.fmt.pix.sizeimage; struct v4l2_streamparm strmp { .type = V4L2_BUF_TYPE_VIDEO_CAPTURE, - }; if (-1 == xioctl(fd_, VIDIOC_G_PARM, &strmp)) { log::warn("Fallback to simulation mode: {} error {}, {}", "VIDIOC_G_PARM", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } strmp.parm.capture.timeperframe.numerator = 1; strmp.parm.capture.timeperframe.denominator = fps; if (-1 == xioctl(fd_, VIDIOC_S_PARM, &strmp)) { log::warn("Fallback to simulation mode: {} error {}, {}", "VIDIOC_S_PARM", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } @@ -319,11 +325,13 @@ class V4L2 { if (-1 == xioctl(fd_, VIDIOC_REQBUFS, &req)) { if (EINVAL == errno) { log::warn("Fallback to simulation mode: {} does not support memory mapping", dev_name); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } else { log::warn("Fallback to simulation mode: {} error {}, {}", "VIDIOC_REQBUFS", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } } @@ -352,7 +360,8 @@ class V4L2 { /* enqueue an empty (capturing) or filled (output) buffer in the driver's incoming queue */ if (-1 == xioctl(fd_, VIDIOC_QBUF, &buf)) { log::warn("Fallback to simulation mode: {} error {}, {}", "VIDIOC_QBUF", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } } @@ -366,7 +375,8 @@ class V4L2 { /* Start streaming I/O */ if (-1 == xioctl(fd_, VIDIOC_STREAMON, &type)) { log::warn("Fallback to simulation mode: {} error {}, {}\n", "VIDIOC_STREAMON", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } @@ -376,7 +386,8 @@ class V4L2 { efd_ = epoll_create1(0); if (-1 == efd_) { log::warn("Fallback to simulation mode: {} error {}, {}", "epoll_create1", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } @@ -386,7 +397,8 @@ class V4L2 { if (-1 == epoll_ctl(efd_, EPOLL_CTL_ADD, fd_, &event)) { log::warn("Fallback to simulation mode: {} error {}, {}", "epoll_ctl", errno, strerror(errno)); - sim_mode_ = true;; + sim_mode_ = true; + ; return; } } @@ -432,13 +444,13 @@ class V4L2 { // Fill by dummy image Halide::Runtime::Buffer img_16(width_, height_); img_16.fill(0); - for(int y = (index_ / 2) % 2; y < height_ ; y+=2){ - for(int x = index_ % 2; x < width_; x+=2){ - img_16(x, y) = 65535 ; + for (int y = (index_ / 2) % 2; y < height_; y += 2) { + for (int x = index_ % 2; x < width_; x += 2) { + img_16(x, y) = 65535; } - } + } buf.copy_from(img_16); - auto size = width_ * height_ * sizeof(uint16_t); + auto size = width_ * height_ * sizeof(uint16_t); std::vector data(size); memcpy(data.data(), img_16.data(), size); ion::bb::image_io::image_cache[id_] = data; @@ -446,47 +458,47 @@ class V4L2 { } img.for_each_element([&](int x, int y, int c) { - img(x, y, c) = pow((float)(img(x, y, c)), 2.2) ; + img(x, y, c) = pow((float)(img(x, y, c)), 2.2); }); - std::vector r_planar(width_*height_); - std::vector g_planar(width_*height_); - std::vector b_planar(width_*height_); + std::vector r_planar(width_ * height_); + std::vector g_planar(width_ * height_); + std::vector b_planar(width_ * height_); - memcpy(r_planar.data(), img.data(), width_* height_* sizeof(float)); - memcpy(g_planar.data(), img.data() + width_ * height_ * 1, width_* height_* sizeof(float)); - memcpy(b_planar.data(), img.data() + width_ * height_ * 2, width_* height_* sizeof(float)); + memcpy(r_planar.data(), img.data(), width_ * height_ * sizeof(float)); + memcpy(g_planar.data(), img.data() + width_ * height_ * 1, width_ * height_ * sizeof(float)); + memcpy(b_planar.data(), img.data() + width_ * height_ * 2, width_ * height_ * sizeof(float)); - std::transform(r_planar.begin(), r_planar.end(), r_planar.begin(), [&](float x){return std::min(std::max(x * gain_r_ + offset_, 0.f), 1.f);}); - std::transform(g_planar.begin(), g_planar.end(), g_planar.begin(), [&](float x){return std::min(std::max(x * gain_g_ + offset_, 0.f), 1.f);}); - std::transform(b_planar.begin(), b_planar.end(), b_planar.begin(), [&](float x){return std::min(std::max(x * gain_b_ + offset_, 0.f), 1.f);}); + std::transform(r_planar.begin(), r_planar.end(), r_planar.begin(), [&](float x) { return std::min(std::max(x * gain_r_ + offset_, 0.f), 1.f); }); + std::transform(g_planar.begin(), g_planar.end(), g_planar.begin(), [&](float x) { return std::min(std::max(x * gain_g_ + offset_, 0.f), 1.f); }); + std::transform(b_planar.begin(), b_planar.end(), b_planar.begin(), [&](float x) { return std::min(std::max(x * gain_b_ + offset_, 0.f), 1.f); }); - std::vector processed_img_arr(width_*height_); + std::vector processed_img_arr(width_ * height_); int idx = 0; for (int j = 0; j < height_; j++) { int evenRow = j % 2 == 0; for (int i = 0; i < width_; i++) { int evenCol = i % 2 == 0; switch (bayer_pattern(pixel_format_)) { - case 0: // RGGB - processed_img_arr[idx] = evenRow ? (evenCol ? r_planar[idx] : g_planar[idx]) : (evenCol ? g_planar[idx] : b_planar[idx]); - break; - case 1: // BGGR - processed_img_arr[idx] = evenRow ? (evenCol ? b_planar[idx] : g_planar[idx]) : (evenCol ? g_planar[idx] : r_planar[idx]); - break; - case 2: // GRBG - processed_img_arr[idx] = evenRow ? (evenCol ? g_planar[idx] : r_planar[idx]) : (evenCol ? b_planar[idx] : g_planar[idx]); - break; - case 3: // GBRG - processed_img_arr[idx] = evenRow ? (evenCol ? g_planar[idx] : b_planar[idx]) : (evenCol ? r_planar[idx] : g_planar[idx]); - break; + case 0: // RGGB + processed_img_arr[idx] = evenRow ? (evenCol ? r_planar[idx] : g_planar[idx]) : (evenCol ? g_planar[idx] : b_planar[idx]); + break; + case 1: // BGGR + processed_img_arr[idx] = evenRow ? (evenCol ? b_planar[idx] : g_planar[idx]) : (evenCol ? g_planar[idx] : r_planar[idx]); + break; + case 2: // GRBG + processed_img_arr[idx] = evenRow ? (evenCol ? g_planar[idx] : r_planar[idx]) : (evenCol ? b_planar[idx] : g_planar[idx]); + break; + case 3: // GBRG + processed_img_arr[idx] = evenRow ? (evenCol ? g_planar[idx] : b_planar[idx]) : (evenCol ? r_planar[idx] : g_planar[idx]); + break; } - idx+=1; + idx += 1; } } - std::vector bit_shifted_img_arr(width_*height_); - for (int i = 0;i bit_shifted_img_arr(width_ * height_); + for (int i = 0; i < width_ * height_; i++) { float val = processed_img_arr[i]; val *= (float)((1 << bit_width_) - 1); val = val * (float)(1 << bit_shift_) + val / (float)(1 << (bit_width_ - bit_shift_)); @@ -498,32 +510,30 @@ class V4L2 { ion::bb::image_io::image_cache[id_] = data; } - -// Created by : Harris Zhu -// Filename : rgb2I420.cpp -// Avthor : Harris Zhu -//======================================================================= + // Created by : Harris Zhu + // Filename : rgb2I420.cpp + // Avthor : Harris Zhu + //======================================================================= #include #include #define BORDER_INTERPOLATE(x, l) (x < 0 ? 0 : (x >= l ? l - 1 : x)) template - void rgb2YCrCb(T *destination, Halide::Runtime::Buffer &rgb, int width, int height){ - for(int y = 0; y < height ; y++){ - for(int x = 0; x < width; x++){ + void rgb2YCrCb(T *destination, Halide::Runtime::Buffer &rgb, int width, int height) { + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { T r = rgb(x, y, 0); T g = rgb(x, y, 1); T b = rgb(x, y, 2); - T Yy = 0.299 * r + 0.587 * g + 0.114 * b ; - T Cr = (r-Yy) * 0.713 + 128; - T Cb = (b-Yy) * 0.564 + 128; - destination[(x+y*width)*3] = Yy; - destination[(x+y*width)*3+1] = Cr; - destination[(x+y*width)*3+2] = Cb; + T Yy = 0.299 * r + 0.587 * g + 0.114 * b; + T Cr = (r - Yy) * 0.713 + 128; + T Cb = (b - Yy) * 0.564 + 128; + destination[(x + y * width) * 3] = Yy; + destination[(x + y * width) * 3 + 1] = Cr; + destination[(x + y * width) * 3 + 2] = Cb; } } - } template @@ -533,8 +543,9 @@ class V4L2 { memcpy(buf.data(), it->second.data(), it->second.size()); return; } - Halide::Runtime::Buffer img (width_, height_, 3); - bool is_loaded = get_image(url_, img, width_, height_);; + Halide::Runtime::Buffer img(width_, height_, 3); + bool is_loaded = get_image(url_, img, width_, height_); + ; std::vector yuyv_img(2 * width_ * height_); if (!is_loaded) { // Fill by dummy image @@ -551,9 +562,9 @@ class V4L2 { for (int y = 0; y < height_; ++y) { for (int x = 0; x < width_; ++x) { // Y - yuyv_img[2 * width_ * y + 2 * x + 0] = yuv[( x + y * width_) * 3]; + yuyv_img[2 * width_ * y + 2 * x + 0] = yuv[(x + y * width_) * 3]; // Cb or Cr - yuyv_img[2 * width_ * y + 2 * x + 1] = ((x % 2) == 1) ? yuv[(y * width_ + x) * 3 + 1]:yuv[(y * width_ + x) * 3 + 2]; + yuyv_img[2 * width_ * y + 2 * x + 1] = ((x % 2) == 1) ? yuv[(y * width_ + x) * 3 + 1] : yuv[(y * width_ + x) * 3 + 2]; } } } @@ -563,7 +574,6 @@ class V4L2 { return; } - template void generate(Halide::Runtime::Buffer &buf) { @@ -632,7 +642,6 @@ class V4L2 { memcpy(buf.data(), reinterpret_cast(next_buffer_.m.userptr), buf.size_in_bytes()); } - private: int fd_; std::vector buffers_; @@ -675,7 +684,7 @@ extern "C" ION_EXPORT int ion_bb_image_io_v4l2( int32_t fps, int32_t width, int32_t height, uint32_t pixel_format, - uint32_t force_sim_mode, // Do not use bool to avoid LLVM codegen failure + uint32_t force_sim_mode, // Do not use bool to avoid LLVM codegen failure // Parameters for simulation halide_buffer_t *url_buf, float gain_r, float gain_g, float gain_b, @@ -693,7 +702,7 @@ extern "C" ION_EXPORT int ion_bb_image_io_v4l2( return 0; } - auto &v4l2(ion::bb::image_io::V4L2::get_instance(instance_id, index, fps, width, height, pixel_format, gain_r, gain_g, gain_b, offset, bit_width, bit_shift, static_cast(force_sim_mode), reinterpret_cast(url_buf->host))); + auto &v4l2(ion::bb::image_io::V4L2::get_instance(instance_id, index, fps, width, height, pixel_format, gain_r, gain_g, gain_b, offset, bit_width, bit_shift, static_cast(force_sim_mode), reinterpret_cast(url_buf->host))); Halide::Runtime::Buffer obuf(*out); v4l2.get(obuf); @@ -719,7 +728,7 @@ extern "C" int ION_EXPORT ion_bb_image_io_camera(int32_t instance_id, int32_t in return 0; } - auto &v4l2(ion::bb::image_io::V4L2::get_instance(instance_id, index, fps, width, height, V4L2_PIX_FMT_YUYV, 1, 1, 1, 0, 8, 0, false, reinterpret_cast(url_buf->host))); + auto &v4l2(ion::bb::image_io::V4L2::get_instance(instance_id, index, fps, width, height, V4L2_PIX_FMT_YUYV, 1, 1, 1, 0, 8, 0, false, reinterpret_cast(url_buf->host))); Halide::Runtime::Buffer obuf(*out); v4l2.get(obuf); @@ -735,7 +744,4 @@ extern "C" int ION_EXPORT ion_bb_image_io_camera(int32_t instance_id, int32_t in } ION_REGISTER_EXTERN(ion_bb_image_io_camera) - - -#endif // ION_BB_IMAGE_IO_RT_V4L2_H - +#endif // ION_BB_IMAGE_IO_RT_V4L2_H diff --git a/src/bb/image-processing/bb.h b/src/bb/image-processing/bb.h index 106d0be5..1994cd94 100644 --- a/src/bb/image-processing/bb.h +++ b/src/bb/image-processing/bb.h @@ -788,8 +788,7 @@ class BilateralFilter2D : public BuildingBlock { r = { -static_cast(window_size), static_cast(window_size) * 2 + 1, -static_cast(window_size), static_cast(window_size) * 2 + 1, - "r" - }; + "r"}; color_diff = (input_mirror(x + r.x, y + r.y) - input_mirror(x, y)) * (input_mirror(x + r.x, y + r.y) - input_mirror(x, y)); sigma_inv(x, y) = 1 / sigma(x, y); @@ -863,8 +862,7 @@ class BilateralFilter3D : public BuildingBlock { r = { -static_cast(window_size), static_cast(window_size) * 2 + 1, -static_cast(window_size), static_cast(window_size) * 2 + 1, - "r" - }; + "r"}; color_diff = ColorDifference::calc( color_difference_method, @@ -951,8 +949,7 @@ class Convolution : public BuildingBlock { r = { -static_cast(window_size), static_cast(window_size) * 2 + 1, -static_cast(window_size), static_cast(window_size) * 2 + 1, - "r" - }; + "r"}; sum(x, y, Halide::_) += input(x + r.x, y + r.y, Halide::_) * kernel(r.x + window_size, r.y + window_size, Halide::_); output(x, y, Halide::_) = sum(x, y, Halide::_); } @@ -1470,7 +1467,7 @@ class FitImageToCenter : public BuildingBlock { Output output{"output", Halide::type_of(), D}; void generate() { - using namespace Halide; + using namespace Halide; Var x, y; @@ -1993,7 +1990,6 @@ class ColorDynamicAdjustment : public ion::BuildingBlock Expr bv = saturating_cast(cast(input(x, y, 2)) * gain_r); output(x, y, c) = select(c == 0, rv, c == 1, gv, bv); } - }; } // namespace image_processing diff --git a/src/bb/image-processing/rt.h b/src/bb/image-processing/rt.h index 6739a8ce..b9edf419 100644 --- a/src/bb/image-processing/rt.h +++ b/src/bb/image-processing/rt.h @@ -10,12 +10,12 @@ namespace image_processing { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; -} // image_io -} // bb -} // ion +} // namespace image_processing +} // namespace bb +} // namespace ion diff --git a/src/bb/llm/bb.cc b/src/bb/llm/bb.cc index 85c6912c..bd2c72ca 100644 --- a/src/bb/llm/bb.cc +++ b/src/bb/llm/bb.cc @@ -26,20 +26,21 @@ namespace llm { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; -} // llm -} // bb -} // ion +} // namespace llm +} // namespace bb +} // namespace ion #define ION_REGISTER_EXTERN(NAME) static auto ion_register_extern_##NAME = ion::bb::llm::RegisterExtern(#NAME, NAME); std::string escape_escape_sequences(const std::string &str_) { - auto str = str_;; + auto str = str_; + ; std::pair const sequences[]{ {'\a', 'a'}, {'\b', 'b'}, @@ -69,9 +70,9 @@ std::string escape_escape_sequences(const std::string &str_) { // NOTE: Originally defined in llama.cpp // struct llava_context { - struct clip_ctx * ctx_clip = NULL; - struct llama_context * ctx_llama = NULL; - struct llama_model * model = NULL; + struct clip_ctx *ctx_clip = NULL; + struct llama_context *ctx_llama = NULL; + struct llama_model *model = NULL; }; struct clip_image_u8 { @@ -81,10 +82,10 @@ struct clip_image_u8 { std::vector buf; }; -static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { - int N = (int) tokens.size(); +static bool eval_tokens(struct llama_context *ctx_llama, std::vector tokens, int n_batch, int *n_past) { + int N = (int)tokens.size(); for (int i = 0; i < N; i += n_batch) { - int n_eval = (int) tokens.size() - i; + int n_eval = (int)tokens.size() - i; if (n_eval > n_batch) { n_eval = n_batch; } @@ -97,22 +98,22 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens; tokens.push_back(id); return eval_tokens(ctx_llama, tokens, 1, n_past); } -static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ - std::string str2 = str; +static bool eval_string(struct llama_context *ctx_llama, const char *str, int n_batch, int *n_past, bool add_bos) { + std::string str2 = str; std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); eval_tokens(ctx_llama, embd_inp, n_batch, n_past); return true; } -static const char * sample(struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_llama, - int * n_past) { +static const char *sample(struct llama_sampling_context *ctx_sampling, + struct llama_context *ctx_llama, + int *n_past) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); llama_sampling_accept(ctx_sampling, ctx_llama, id, true); static std::string ret; @@ -125,14 +126,14 @@ static const char * sample(struct llama_sampling_context * ctx_sampling, return ret.c_str(); } -struct llava_image_embed * llava_image_embed_make_with_rawbytes(struct clip_ctx * ctx_clip, int n_threads, const std::vector& buf, int32_t width, int32_t height) { - clip_image_u8 * img = clip_image_u8_init(); +struct llava_image_embed *llava_image_embed_make_with_rawbytes(struct clip_ctx *ctx_clip, int n_threads, const std::vector &buf, int32_t width, int32_t height) { + clip_image_u8 *img = clip_image_u8_init(); img->nx = width; img->ny = height; img->buf.resize(3 * img->nx * img->ny); - memcpy(img->buf.data(), reinterpret_cast(buf.data()), buf.size()); + memcpy(img->buf.data(), reinterpret_cast(buf.data()), buf.size()); - float* image_embed = NULL; + float *image_embed = NULL; int n_image_pos = 0; bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos); if (!image_embed_result) { @@ -141,16 +142,16 @@ struct llava_image_embed * llava_image_embed_make_with_rawbytes(struct clip_ctx } clip_image_u8_free(img); - auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed)); + auto result = (llava_image_embed *)malloc(sizeof(llava_image_embed)); result->embed = image_embed; result->n_image_pos = n_image_pos; return result; } -static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params *params, const std::vector& buf, int32_t width, int32_t height) { +static struct llava_image_embed *load_image(llava_context *ctx_llava, gpt_params *params, const std::vector &buf, int32_t width, int32_t height) { // load and preprocess the image - llava_image_embed * embed = NULL; + llava_image_embed *embed = NULL; embed = llava_image_embed_make_with_rawbytes(ctx_llava->ctx_clip, params->n_threads, buf, width, height); if (!embed) { throw std::runtime_error("Failed to embed image from rawbytes"); @@ -159,7 +160,7 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para return embed; } -static std::string process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const std::string & prompt) { +static std::string process_prompt(struct llava_context *ctx_llava, struct llava_image_embed *image_embed, gpt_params *params, const std::string &prompt) { int n_past = 0; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; @@ -172,13 +173,13 @@ static std::string process_prompt(struct llava_context * ctx_llava, struct llava user_prompt = prompt.substr(image_pos + std::string("").length()); if (params->verbose_prompt) { auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, system_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { + for (int i = 0; i < (int)tmp.size(); i++) { ion::log::info("{:6d} -> '{}'", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i])); } } if (params->verbose_prompt) { auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { + for (int i = 0; i < (int)tmp.size(); i++) { ion::log::info("{:6d} -> '{}'", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i])); } } @@ -188,7 +189,7 @@ static std::string process_prompt(struct llava_context * ctx_llava, struct llava user_prompt = prompt + "\nASSISTANT:"; if (params->verbose_prompt) { auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { + for (int i = 0; i < (int)tmp.size(); i++) { ion::log::info("{:6d} -> '{}'", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i])); } } @@ -199,16 +200,16 @@ static std::string process_prompt(struct llava_context * ctx_llava, struct llava eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); // generate the response - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + struct llama_sampling_context *ctx_sampling = llama_sampling_init(params->sparams); std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char *tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); response += tmp; if (strcmp(tmp, "") == 0) break; - if (strstr(tmp, "###")) break; // Yi-VL behavior - if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) - if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6 - if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6 + if (strstr(tmp, "###")) break; // Yi-VL behavior + if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) + if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6 + if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6 } ion::log::debug("system_prompt:{} user_prompt:{} response:{}", system_prompt, user_prompt, escape_escape_sequences(response)); @@ -218,34 +219,33 @@ static std::string process_prompt(struct llava_context * ctx_llava, struct llava return response; } - -static struct llava_context * llava_init(gpt_params * params) { +static struct llava_context *llava_init(gpt_params *params) { llama_log_set(nullptr, nullptr); - const char * clip_path = params->mmproj.c_str(); + const char *clip_path = params->mmproj.c_str(); auto prompt = params->prompt; if (prompt.empty()) { prompt = "describe the image in detail."; } - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/1); llama_backend_init(); llama_numa_init(params->numa); llama_model_params model_params = llama_model_params_from_gpt_params(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model *model = llama_load_model_from_file(params->model.c_str(), model_params); if (model == NULL) { ion::log::error("Unable to load model"); return NULL; } llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); - ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings + ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context *ctx_llama = llama_new_context_with_model(model, ctx_params); if (ctx_llama == NULL) { ion::log::error("Failed to create the llama_context"); @@ -276,7 +276,7 @@ static void llava_new_context_llama(llava_context *ctx_llava, gpt_params *params ctx_llava->ctx_llama = ctx_llama; } -static llava_context * llava_init_without_ctx_llama(gpt_params *params) { +static llava_context *llava_init_without_ctx_llama(gpt_params *params) { const char *clip_path = params->mmproj.c_str(); auto prompt = params->prompt; @@ -306,7 +306,7 @@ static llava_context * llava_init_without_ctx_llama(gpt_params *params) { return ctx_llava; } -static void llava_free(struct llava_context * ctx_llava) { +static void llava_free(struct llava_context *ctx_llava) { if (ctx_llava->ctx_clip) { clip_free(ctx_llava->ctx_clip); ctx_llava->ctx_clip = NULL; @@ -329,8 +329,8 @@ class Llava { return llava; } - static void release_instance(const std::string& id) { - auto& llava = get_instance(); + static void release_instance(const std::string &id) { + auto &llava = get_instance(); llava.keep_running_ = false; llava.thread_.join(); llava_free(llava.ctx_llava_); @@ -341,8 +341,7 @@ class Llava { ~Llava() { } - bool is_initialized() - { + bool is_initialized() { return initialized_; } @@ -360,7 +359,7 @@ class Llava { initialized_ = true; } - std::string process(const std::vector& buf, const std::string& prompt) { + std::string process(const std::vector &buf, const std::string &prompt) { auto image_embed = load_image(ctx_llava_, ¶ms_, buf, width_, height_); @@ -369,15 +368,15 @@ class Llava { llava_image_embed_free(image_embed); - response = response.substr(response.find_last_of('\n')+1); - response = response.substr(0, response.find_last_of("")-4); + response = response.substr(response.find_last_of('\n') + 1); + response = response.substr(0, response.find_last_of("") - 4); llama_kv_cache_clear(ctx_llava_->ctx_llama); return response; } - void post(const Halide::Runtime::Buffer& ibuf, const std::string& prompt) { + void post(const Halide::Runtime::Buffer &ibuf, const std::string &prompt) { std::unique_lock lock(mutex_); // clear old @@ -385,17 +384,18 @@ class Llava { task_queue_.pop(); } - task_queue_.emplace(std::make_shared>(ibuf.data(), ibuf.data()+ibuf.size_in_bytes()), prompt); + task_queue_.emplace(std::make_shared>(ibuf.data(), ibuf.data() + ibuf.size_in_bytes()), prompt); cv_.notify_one(); } - std::string retrieve(){ + std::string retrieve() { std::unique_lock lock(mutex_); return response_; } private: - Llava() : keep_running_(true), initialized_(false) { + Llava() + : keep_running_(true), initialized_(false) { params_.model = "ggml-mistral-q_4_k.gguf"; params_.mmproj = "mmproj-mistral7b-f16-q6_k.gguf"; // params_.model = "llava-phi-3-mini-gguf/ggml-model-int4.gguf"; @@ -425,11 +425,10 @@ class Llava { } } - static void entry_point(Llava* obj) { + static void entry_point(Llava *obj) { try { obj->thread_main(); - } - catch (const std::exception& e) { + } catch (const std::exception &e) { ::std::unique_lock<::std::mutex> lock(obj->mutex_); ion::log::error(e.what()); obj->ep_ = ::std::current_exception(); @@ -439,7 +438,6 @@ class Llava { gpt_params params_; llava_context *ctx_llava_; - std::thread thread_; std::mutex mutex_; std::condition_variable cv_; @@ -454,19 +452,17 @@ class Llava { std::string response_; }; -} // rt -} // llm -} // bb -} // ion +} // namespace rt +} // namespace llm +} // namespace bb +} // namespace ion -extern "C" -int ION_EXPORT ion_bb_llm_llava_dispose(const char *id) { +extern "C" int ION_EXPORT ion_bb_llm_llava_dispose(const char *id) { ion::bb::llm::rt::Llava::release_instance(id); return 0; } -extern "C" -ION_EXPORT int ion_bb_llm_llava(halide_buffer_t *in, halide_buffer_t *prompt, int32_t width, int32_t height, halide_buffer_t *out) { +extern "C" ION_EXPORT int ion_bb_llm_llava(halide_buffer_t *in, halide_buffer_t *prompt, int32_t width, int32_t height, halide_buffer_t *out) { try { if (in->is_bounds_query() || prompt->is_bounds_query()) { if (in->is_bounds_query()) { @@ -480,7 +476,7 @@ ION_EXPORT int ion_bb_llm_llava(halide_buffer_t *in, halide_buffer_t *prompt, in if (prompt->is_bounds_query()) { prompt->dim[0].min = 0; - prompt->dim[0].extent = 1024; // TBD + prompt->dim[0].extent = 1024; // TBD } return 0; @@ -490,12 +486,12 @@ ION_EXPORT int ion_bb_llm_llava(halide_buffer_t *in, halide_buffer_t *prompt, in Halide::Runtime::Buffer pbuf(*prompt); Halide::Runtime::Buffer obuf(*out); - auto& llava = ion::bb::llm::rt::Llava::get_instance(); + auto &llava = ion::bb::llm::rt::Llava::get_instance(); if (!llava.is_initialized()) { llava.init(width, height); } - //auto response = llava.process(ibuf, std::string(reinterpret_cast(pbuf.data()))); - llava.post(ibuf, std::string(reinterpret_cast(pbuf.data()))); + // auto response = llava.process(ibuf, std::string(reinterpret_cast(pbuf.data()))); + llava.post(ibuf, std::string(reinterpret_cast(pbuf.data()))); auto response = llava.retrieve(); obuf.fill(0); diff --git a/src/bb/llm/bb.h b/src/bb/llm/bb.h index 97abdca6..e62a9376 100644 --- a/src/bb/llm/bb.h +++ b/src/bb/llm/bb.h @@ -18,13 +18,13 @@ class Llava : public BuildingBlock { BuildingBlockParam height{"height", 480}; void generate() { - using namespace Halide; + using namespace Halide; // NOTE: These tricks is required for the input parameter which is passed as an external function argument Func input_; input_(_) = input(_); input_.compute_root(); - + Func prompt_; prompt_(_) = prompt(_); prompt_.compute_root(); diff --git a/src/bb/llm/clip.h b/src/bb/llm/clip.h index 477076a2..c7524a4f 100644 --- a/src/bb/llm/clip.h +++ b/src/bb/llm/clip.h @@ -6,17 +6,17 @@ #include #ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define CLIP_API __declspec(dllexport) -# else -# define CLIP_API __declspec(dllimport) -# endif -# else -# define CLIP_API __attribute__ ((visibility ("default"))) -# endif +#if defined(_WIN32) && !defined(__MINGW32__) +#ifdef LLAMA_BUILD +#define CLIP_API __declspec(dllexport) #else -# define CLIP_API +#define CLIP_API __declspec(dllimport) +#endif +#else +#define CLIP_API __attribute__((visibility("default"))) +#endif +#else +#define CLIP_API #endif struct clip_ctx; @@ -28,59 +28,59 @@ extern "C" { struct clip_ctx; struct clip_image_u8_batch { - struct clip_image_u8 * data; + struct clip_image_u8 *data; size_t size; }; struct clip_image_f32_batch { - struct clip_image_f32 * data; + struct clip_image_f32 *data; size_t size; }; -CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); -CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity); +CLIP_API struct clip_ctx *clip_model_load(const char *fname, int verbosity); +CLIP_API struct clip_ctx *clip_model_load_cpu(const char *fname, int verbosity); -CLIP_API void clip_free(struct clip_ctx * ctx); +CLIP_API void clip_free(struct clip_ctx *ctx); -CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); +CLIP_API size_t clip_embd_nbytes(const struct clip_ctx *ctx); -CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); +CLIP_API int32_t clip_image_size(const struct clip_ctx *ctx); +CLIP_API int32_t clip_patch_size(const struct clip_ctx *ctx); +CLIP_API int32_t clip_hidden_size(const struct clip_ctx *ctx); // TODO: should be enum, not string -CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); +CLIP_API const char *clip_patch_merge_type(const struct clip_ctx *ctx); -CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); +CLIP_API const int32_t *clip_image_grid(const struct clip_ctx *ctx); -CLIP_API int clip_n_patches (const struct clip_ctx * ctx); -CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); +CLIP_API int clip_n_patches(const struct clip_ctx *ctx); +CLIP_API int clip_n_mmproj_embd(const struct clip_ctx *ctx); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_u8 *clip_image_u8_init(); +CLIP_API struct clip_image_f32 *clip_image_f32_init(); -CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); -CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); -CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); -CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +CLIP_API void clip_image_u8_free(struct clip_image_u8 *img); +CLIP_API void clip_image_f32_free(struct clip_image_f32 *img); +CLIP_API void clip_image_u8_batch_free(struct clip_image_u8_batch *batch); +CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch *batch); -CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); +CLIP_API bool clip_image_load_from_file(const char *fname, struct clip_image_u8 *img); /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ -CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); +CLIP_API bool clip_image_load_from_bytes(const unsigned char *bytes, size_t bytes_length, struct clip_image_u8 *img); /** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */ -CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); +CLIP_API bool clip_image_preprocess(struct clip_ctx *ctx, const struct clip_image_u8 *img, struct clip_image_f32_batch *res_imgs); -CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); +CLIP_API struct ggml_tensor *clip_get_newline_tensor(const struct clip_ctx *ctx); -CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec); -CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); +CLIP_API bool clip_image_encode(struct clip_ctx *ctx, int n_threads, struct clip_image_f32 *img, float *vec); +CLIP_API bool clip_image_batch_encode(struct clip_ctx *ctx, int n_threads, const struct clip_image_f32_batch *imgs, float *vec); -CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); +CLIP_API bool clip_model_quantize(const char *fname_inp, const char *fname_out, int itype); #ifdef __cplusplus } #endif -#endif // CLIP_H +#endif // CLIP_H diff --git a/src/bb/llm/common.h b/src/bb/llm/common.h index bccd74a9..7a54ae19 100644 --- a/src/bb/llm/common.h +++ b/src/bb/llm/common.h @@ -22,23 +22,32 @@ #define DIRECTORY_SEPARATOR '\\' #else #define DIRECTORY_SEPARATOR '/' -#endif // _WIN32 - -#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) -#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) - -#define print_build_info() do { \ - fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ - fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ -} while(0) +#endif // _WIN32 + +#define die(msg) \ + do { \ + fputs("error: " msg "\n", stderr); \ + exit(1); \ + } while (0) +#define die_fmt(fmt, ...) \ + do { \ + fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); \ + exit(1); \ + } while (0) + +#define print_build_info() \ + do { \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ + } while (0) #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" // build info extern int LLAMA_BUILD_NUMBER; -extern char const * LLAMA_COMMIT; -extern char const * LLAMA_COMPILER; -extern char const * LLAMA_BUILD_TARGET; +extern char const *LLAMA_COMMIT; +extern char const *LLAMA_COMPILER; +extern char const *LLAMA_BUILD_TARGET; struct llama_control_vector_load_info; @@ -54,142 +63,142 @@ int32_t cpu_get_num_math(); // struct gpt_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed + uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_threads = cpu_get_num_math(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads = cpu_get_num_math(); + int32_t n_threads_draft = -1; + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch_draft = -1; - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 512; // context size - int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. - int32_t grp_attn_n = 1; // group-attention factor - int32_t grp_attn_w = 512; // group-attention width - int32_t n_print = -1; // print token count every n tokens (-1 = disabled) - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor - float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor - float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim - int32_t yarn_orig_ctx = 0; // YaRN original context length - float defrag_thold = -1.0f; // KV cache defragmentation threshold - std::string rpc_servers = ""; // comma separated list of RPC servers + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 512; // context size + int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t n_beams = 0; // if non-zero then use beam search of given width. + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = 32.0f; // YaRN low correction dim + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length + float defrag_thold = -1.0f; // KV cache defragmentation threshold + std::string rpc_servers = ""; // comma separated list of RPC servers ggml_backend_sched_eval_callback cb_eval = nullptr; - void * cb_eval_user_data = nullptr; + void *cb_eval_user_data = nullptr; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; - enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings // // sampling parameters struct llama_sampling_params sparams; - std::string model = ""; // model path - std::string model_draft = ""; // draft model for speculative decoding - std::string model_alias = "unknown"; // model alias - std::string model_url = ""; // model url to download - std::string hf_repo = ""; // HF repo - std::string hf_file = ""; // HF file - std::string prompt = ""; - std::string prompt_file = ""; // store the external prompt file name - std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::vector antiprompt; // string upon seeing which more user input is prompted - std::string logdir = ""; // directory in which to save YAML log files - std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding - std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding - std::string logits_file = ""; // file for saving *all* logits + std::string model = ""; // model path + std::string model_draft = ""; // draft model for speculative decoding + std::string model_alias = "unknown"; // model alias + std::string model_url = ""; // model url to download + std::string hf_repo = ""; // HF repo + std::string hf_file = ""; // HF file + std::string prompt = ""; + std::string prompt_file = ""; // store the external prompt file name + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state + std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with + std::vector antiprompt; // string upon seeing which more user input is prompted + std::string logdir = ""; // directory in which to save YAML log files + std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding + std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding + std::string logits_file = ""; // file for saving *all* logits std::vector kv_overrides; // TODO: avoid tuple, use struct - std::vector> lora_adapter; // lora adapter path with user defined scale - std::string lora_base = ""; // base model path for the lora adapter - - std::vector control_vectors; // control vector with user defined scale - - int32_t control_vector_layer_start = -1; // layer range for control vector - int32_t control_vector_layer_end = -1; // layer range for control vector - - int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. - int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line - // (which is more convenient to use for plotting) - // - bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt - size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - - bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt - size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed - - bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt - size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed - - bool kl_divergence = false; // compute KL divergence - - bool random_prompt = false; // do not randomize prompt if none provided - bool use_color = false; // use color to distinguish generations and inputs - bool interactive = false; // interactive mode - bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode - bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) - bool chatml = false; // chatml mode (used for models trained on chatml syntax) - bool prompt_cache_all = false; // save user input and generations to prompt cache - bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it - - bool embedding = false; // get only sentence embedding - bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" - bool interactive_first = false; // wait for user input immediately - bool multiline_input = false; // reverse the usage of `\` - bool simple_io = false; // improves compatibility with subprocesses and limited consoles - bool cont_batching = true; // insert new sequences for decoding on-the-fly - bool flash_attn = false; // flash attention - - bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool ignore_eos = false; // ignore generated EOS tokens - bool instruct = false; // instruction mode (used for Alpaca models) - bool logits_all = false; // return logits for all tokens in the batch - bool use_mmap = true; // use mmap for faster loads - bool use_mlock = false; // use mlock to keep model in memory - bool verbose_prompt = false; // print prompt tokens before generation - bool display_prompt = true; // print prompt before generation - bool infill = false; // use infill mode - bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes - bool no_kv_offload = false; // disable KV offloading - bool warmup = true; // warmup run - bool check_tensors = false; // validate tensor data - - std::string cache_type_k = "f16"; // KV cache data type for the K - std::string cache_type_v = "f16"; // KV cache data type for the V + std::vector> lora_adapter; // lora adapter path with user defined scale + std::string lora_base = ""; // base model path for the lora adapter + + std::vector control_vectors; // control vector with user defined scale + + int32_t control_vector_layer_start = -1; // layer range for control vector + int32_t control_vector_layer_end = -1; // layer range for control vector + + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + + bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt + size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + + bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt + size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed + + bool kl_divergence = false; // compute KL divergence + + bool random_prompt = false; // do not randomize prompt if none provided + bool use_color = false; // use color to distinguish generations and inputs + bool interactive = false; // interactive mode + bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode + bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) + bool chatml = false; // chatml mode (used for models trained on chatml syntax) + bool prompt_cache_all = false; // save user input and generations to prompt cache + bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + + bool embedding = false; // get only sentence embedding + bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" + bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` + bool simple_io = false; // improves compatibility with subprocesses and limited consoles + bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention + + bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix + bool ignore_eos = false; // ignore generated EOS tokens + bool instruct = false; // instruction mode (used for Alpaca models) + bool logits_all = false; // return logits for all tokens in the batch + bool use_mmap = true; // use mmap for faster loads + bool use_mlock = false; // use mlock to keep model in memory + bool verbose_prompt = false; // print prompt tokens before generation + bool display_prompt = true; // print prompt before generation + bool infill = false; // use infill mode + bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes + bool no_kv_offload = false; // disable KV offloading + bool warmup = true; // warmup run + bool check_tensors = false; // validate tensor data + + std::string cache_type_k = "f16"; // KV cache data type for the K + std::string cache_type_v = "f16"; // KV cache data type for the V // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector - std::vector image; // path to image file(s) + std::string mmproj = ""; // path to multimodal projector + std::vector image; // path to image file(s) }; -void gpt_params_handle_model_default(gpt_params & params); +void gpt_params_handle_model_default(gpt_params ¶ms); -bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); -bool gpt_params_parse (int argc, char ** argv, gpt_params & params); -bool gpt_params_find_arg (int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param); -void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params); +bool gpt_params_parse_ex(int argc, char **argv, gpt_params ¶ms); +bool gpt_params_parse(int argc, char **argv, gpt_params ¶ms); +bool gpt_params_find_arg(int argc, char **argv, const std::string &arg, gpt_params ¶ms, int &i, bool &invalid_param); +void gpt_params_print_usage(int argc, char **argv, const gpt_params ¶ms); -std::string gpt_params_get_system_info(const gpt_params & params); +std::string gpt_params_get_system_info(const gpt_params ¶ms); // // String utils @@ -197,19 +206,19 @@ std::string gpt_params_get_system_info(const gpt_params & params); std::vector string_split(std::string input, char separator); -std::string string_strip(const std::string & str); +std::string string_strip(const std::string &str); std::string string_get_sortable_timestamp(); -std::string string_random_prompt(std::mt19937 & rng); +std::string string_random_prompt(std::mt19937 &rng); -bool string_parse_kv_override(const char * data, std::vector & overrides); -void string_process_escapes(std::string & input); +bool string_parse_kv_override(const char *data, std::vector &overrides); +void string_process_escapes(std::string &input); // // Filesystem utils // -bool fs_validate_filename(const std::string & filename); -bool fs_create_directory_with_parents(const std::string & path); +bool fs_validate_filename(const std::string &filename); +bool fs_create_directory_with_parents(const std::string &path); std::string fs_get_cache_directory(); @@ -218,24 +227,24 @@ std::string fs_get_cache_directory(); // // TODO: avoid tuplue, use struct -std::tuple llama_init_from_gpt_params(gpt_params & params); +std::tuple llama_init_from_gpt_params(gpt_params ¶ms); -struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params ¶ms); +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params ¶ms); -struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params); -struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params); +struct llama_model *llama_load_model_from_url(const char *model_url, const char *path_model, const struct llama_model_params ¶ms); +struct llama_model *llama_load_model_from_hf(const char *repo, const char *file, const char *path_model, const struct llama_model_params ¶ms); // Batch utils -void llama_batch_clear(struct llama_batch & batch); +void llama_batch_clear(struct llama_batch &batch); void llama_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - const std::vector & seq_ids, - bool logits); + struct llama_batch &batch, + llama_token id, + llama_pos pos, + const std::vector &seq_ids, + bool logits); // // Vocab utils @@ -244,23 +253,23 @@ void llama_batch_add( // tokenizes a string into a vector of tokens // should work similar to Python's `tokenizer.encode` std::vector llama_tokenize( - const struct llama_context * ctx, - const std::string & text, - bool add_special, - bool parse_special = false); + const struct llama_context *ctx, + const std::string &text, + bool add_special, + bool parse_special = false); std::vector llama_tokenize( - const struct llama_model * model, - const std::string & text, - bool add_special, - bool parse_special = false); + const struct llama_model *model, + const std::string &text, + bool add_special, + bool parse_special = false); // tokenizes a token into a piece, optionally renders special/control tokens // should work similar to Python's `tokenizer.id_to_piece` std::string llama_token_to_piece( - const struct llama_context * ctx, - llama_token token, - bool special = true); + const struct llama_context *ctx, + llama_token token, + bool special = true); // TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function // that takes into account the tokenizer type and decides how to handle the leading space @@ -269,36 +278,36 @@ std::string llama_token_to_piece( // should work similar to Python's `tokenizer.decode` // removes the leading space from the first non-BOS token std::string llama_detokenize_spm( - llama_context * ctx, - const std::vector & tokens); + llama_context *ctx, + const std::vector &tokens); // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` std::string llama_detokenize_bpe( - llama_context * ctx, - const std::vector & tokens); + llama_context *ctx, + const std::vector &tokens); // Uses the value from the model metadata if possible, otherwise // defaults to true when model type is SPM, otherwise false. -bool llama_should_add_bos_token(const llama_model * model); +bool llama_should_add_bos_token(const llama_model *model); // // KV cache utils // // Dump the KV cache view with the number of sequences per cell. -void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); +void llama_kv_cache_dump_view(const llama_kv_cache_view &view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); +void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view &view, int row_size = 40); // // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n); +void llama_embd_normalize(const float *inp, float *out, int n); -float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); +float llama_embd_similarity_cos(const float *embd1, const float *embd2, int n); // // Control vector utils @@ -319,25 +328,24 @@ struct llama_control_vector_load_info { // Load control vectors, scale each by strength, and add them together. // On error, returns {-1, empty} -llama_control_vector_data llama_control_vector_load(const std::vector & load_infos); +llama_control_vector_data llama_control_vector_load(const std::vector &load_infos); // // Split utils // -static const char * const LLM_KV_SPLIT_NO = "split.no"; -static const char * const LLM_KV_SPLIT_COUNT = "split.count"; -static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; +static const char *const LLM_KV_SPLIT_NO = "split.no"; +static const char *const LLM_KV_SPLIT_COUNT = "split.count"; +static const char *const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; // // YAML utils // -void yaml_dump_vector_float (FILE * stream, const char * prop_name, const std::vector & data); -void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std::vector & data); -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); +void yaml_dump_vector_float(FILE *stream, const char *prop_name, const std::vector &data); +void yaml_dump_vector_int(FILE *stream, const char *prop_name, const std::vector &data); +void yaml_dump_string_multiline(FILE *stream, const char *prop_name, const char *data); void yaml_dump_non_result_info( - FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); - + FILE *stream, const gpt_params ¶ms, const llama_context *lctx, + const std::string ×tamp, const std::vector &prompt_tokens, const char *model_desc); diff --git a/src/bb/llm/grammar-parser.h b/src/bb/llm/grammar-parser.h index 9037d727..2309a083 100644 --- a/src/bb/llm/grammar-parser.h +++ b/src/bb/llm/grammar-parser.h @@ -17,13 +17,13 @@ #include namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; +struct parse_state { + std::map symbol_ids; + std::vector> rules; - std::vector c_rules(); - }; + std::vector c_rules(); +}; - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} +parse_state parse(const char *src); +void print_grammar(FILE *file, const parse_state &state); +} // namespace grammar_parser diff --git a/src/bb/llm/json-schema-to-grammar.h b/src/bb/llm/json-schema-to-grammar.h index 41623b34..ddc04f06 100644 --- a/src/bb/llm/json-schema-to-grammar.h +++ b/src/bb/llm/json-schema-to-grammar.h @@ -5,4 +5,4 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json &schema); diff --git a/src/bb/llm/llava.h b/src/bb/llm/llava.h index 19212f6e..4fa89d6d 100644 --- a/src/bb/llm/llava.h +++ b/src/bb/llm/llava.h @@ -4,17 +4,17 @@ #include "ggml.h" #ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define LLAVA_API __declspec(dllexport) -# else -# define LLAVA_API __declspec(dllimport) -# endif -# else -# define LLAVA_API __attribute__ ((visibility ("default"))) -# endif +#if defined(_WIN32) && !defined(__MINGW32__) +#ifdef LLAMA_BUILD +#define LLAVA_API __declspec(dllexport) #else -# define LLAVA_API +#define LLAVA_API __declspec(dllimport) +#endif +#else +#define LLAVA_API __attribute__((visibility("default"))) +#endif +#else +#define LLAVA_API #endif struct clip_ctx; @@ -24,24 +24,24 @@ extern "C" { #endif struct llava_image_embed { - float * embed; + float *embed; int n_image_pos; }; /** sanity check for clip <-> llava embed size match */ -LLAVA_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip); +LLAVA_API bool llava_validate_embed_size(const struct llama_context *ctx_llama, const struct clip_ctx *ctx_clip); -LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); +LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx *ctx_clip, int n_threads, const struct clip_image_u8 *img, float **image_embd_out, int *n_img_pos_out); /** build an image embed from image file bytes */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); +LLAVA_API struct llava_image_embed *llava_image_embed_make_with_bytes(struct clip_ctx *ctx_clip, int n_threads, const unsigned char *image_bytes, int image_bytes_length); /** build an image embed from a path to an image filename */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); +LLAVA_API struct llava_image_embed *llava_image_embed_make_with_filename(struct clip_ctx *ctx_clip, int n_threads, const char *image_path); +LLAVA_API void llava_image_embed_free(struct llava_image_embed *embed); /** free an embedding made with llava_image_embed_make_* */ /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); +LLAVA_API bool llava_eval_image_embed(struct llama_context *ctx_llama, const struct llava_image_embed *embed, int n_batch, int *n_past); #ifdef __cplusplus } diff --git a/src/bb/llm/sampling.h b/src/bb/llm/sampling.h index f552487c..b8c705d4 100644 --- a/src/bb/llm/sampling.h +++ b/src/bb/llm/sampling.h @@ -13,36 +13,36 @@ // sampler types enum class llama_sampler_type : char { - TOP_K = 'k', - TOP_P = 'p', - MIN_P = 'm', - TFS_Z = 'f', - TYPICAL_P = 'y', + TOP_K = 'k', + TOP_P = 'p', + MIN_P = 'm', + TFS_Z = 'f', + TYPICAL_P = 'y', TEMPERATURE = 't' }; // sampling parameters typedef struct llama_sampling_params { - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context std::vector samplers_sequence = { llama_sampler_type::TOP_K, @@ -50,20 +50,19 @@ typedef struct llama_sampling_params { llama_sampler_type::TYPICAL_P, llama_sampler_type::TOP_P, llama_sampler_type::MIN_P, - llama_sampler_type::TEMPERATURE - }; + llama_sampler_type::TEMPERATURE}; std::string grammar; // optional BNF-like grammar to constrain sampling // Classifier-Free Guidance // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // how strong is guidance + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // how strong is guidance - std::unordered_map logit_bias; // logit bias for specific tokens + std::unordered_map logit_bias; // logit bias for specific tokens std::vector penalty_prompt_tokens; - bool use_penalty_prompt_tokens = false; + bool use_penalty_prompt_tokens = false; } llama_sampling_params; // general sampler context @@ -75,15 +74,15 @@ struct llama_sampling_context { // mirostat sampler state float mirostat_mu; - llama_grammar * grammar; + llama_grammar *grammar; // internal grammar_parser::parse_state parsed_grammar; // TODO: replace with ring-buffer - std::vector prev; + std::vector prev; std::vector cur; - size_t n_valid; // Number of correct top tokens with correct probabilities. + size_t n_valid; // Number of correct top tokens with correct probabilities. std::mt19937 rng; }; @@ -91,37 +90,37 @@ struct llama_sampling_context { #include "common.h" // Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); +struct llama_sampling_context *llama_sampling_init(const struct llama_sampling_params ¶ms); -void llama_sampling_free(struct llama_sampling_context * ctx); +void llama_sampling_free(struct llama_sampling_context *ctx); // Reset the sampler context // - clear prev tokens // - reset grammar -void llama_sampling_reset(llama_sampling_context * ctx); +void llama_sampling_reset(llama_sampling_context *ctx); // Set the sampler seed -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); +void llama_sampling_set_rng_seed(struct llama_sampling_context *ctx, uint32_t seed); // Copy the sampler context -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); +void llama_sampling_cp(llama_sampling_context *src, llama_sampling_context *dst); // Get the last sampled token -llama_token llama_sampling_last(llama_sampling_context * ctx); +llama_token llama_sampling_last(llama_sampling_context *ctx); // Get a string representation of the last sampled tokens -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); +std::string llama_sampling_prev_str(llama_sampling_context *ctx_sampling, llama_context *ctx_main, int n); // Print sampling parameters into a string -std::string llama_sampling_print(const llama_sampling_params & params); +std::string llama_sampling_print(const llama_sampling_params ¶ms); // Print sampling order into a string -std::string llama_sampling_order_print(const llama_sampling_params & params); +std::string llama_sampling_order_print(const llama_sampling_params ¶ms); std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & names_string); +std::vector llama_sampling_types_from_names(const std::vector &names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string &names_string); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function @@ -141,22 +140,22 @@ std::vector llama_sampling_types_from_chars(const std::strin // - candidates: vector of candidate tokens // llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = -1); + struct llama_sampling_context *ctx_sampling, + struct llama_context *ctx_main, + struct llama_context *ctx_cfg, + int idx = -1); // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. llama_token_data_array llama_sampling_prepare( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = 0, - bool apply_grammar = true, - std::vector * original_logits = nullptr); + struct llama_sampling_context *ctx_sampling, + struct llama_context *ctx_main, + struct llama_context *ctx_cfg, + int idx = 0, + bool apply_grammar = true, + std::vector *original_logits = nullptr); void llama_sampling_accept( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - llama_token id, - bool apply_grammar); + struct llama_sampling_context *ctx_sampling, + struct llama_context *ctx_main, + llama_token id, + bool apply_grammar); diff --git a/src/bb/llm/stb_image.h b/src/bb/llm/stb_image.h index 4766d7e6..d0d2a24d 100644 --- a/src/bb/llm/stb_image.h +++ b/src/bb/llm/stb_image.h @@ -367,12 +367,12 @@ RECENT REVISION HISTORY: #ifndef STBI_NO_STDIO #include -#endif // STBI_NO_STDIO +#endif // STBI_NO_STDIO #define STBI_VERSION 1 enum { - STBI_default = 0, // only used for desired_channels + STBI_default = 0, // only used for desired_channels STBI_grey = 1, STBI_grey_alpha = 2, @@ -406,10 +406,10 @@ extern "C" { // typedef struct { - int (*read)(void * user, char * data, + int (*read)(void *user, char *data, int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip)(void * user, int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof)(void * user); // returns nonzero if we are at end of file/data + void (*skip)(void *user, int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative + int (*eof)(void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -417,24 +417,24 @@ typedef struct { // 8-bits-per-channel interface // -STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, + int *channels_in_file, int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); // for stbi_load_from_file, file pointer is left pointing immediately after image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, - int * comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, + int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t *input); #endif //////////////////////////////////// @@ -442,14 +442,14 @@ STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wc // 16-bits-per-channel interface // -STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, + int *channels_in_file, int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); #endif //////////////////////////////////// @@ -457,53 +457,53 @@ STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * chann // float-per-channel interface // #ifndef STBI_NO_LINEAR -STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * channels_in_file, - int desired_channels); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); #endif #endif #ifndef STBI_NO_HDR STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); STBIDEF void stbi_hdr_to_ldr_scale(float scale); -#endif // STBI_NO_HDR +#endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); STBIDEF void stbi_ldr_to_hdr_scale(float scale); -#endif // STBI_NO_LINEAR +#endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr(char const * filename); -STBIDEF int stbi_is_hdr_from_file(FILE * f); -#endif // STBI_NO_STDIO +STBIDEF int stbi_is_hdr(char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); +#endif // STBI_NO_STDIO // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char * stbi_failure_reason(void); +STBIDEF const char *stbi_failure_reason(void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free(void * retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * clbk, void * user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp); -STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp); -STBIDEF int stbi_is_16_bit(char const * filename); -STBIDEF int stbi_is_16_bit_from_file(FILE * f); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit(char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif // for image formats that explicitly notate that they have premultiplied alpha, @@ -527,14 +527,14 @@ STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_fli // ZLIB client - used by PNG, available for other purposes -STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen); -STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, - int parse_header); -STBIDEF char * stbi_zlib_decode_malloc(const char * buffer, int len, int * outlen); -STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, + int parse_header); +STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); -STBIDEF char * stbi_zlib_decode_noheader_malloc(const char * buffer, int len, int * outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); #ifdef __cplusplus } @@ -543,12 +543,12 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha // // //// end header file ///////////////////////////////////////////////////// -#endif // STBI_INCLUDE_STB_IMAGE_H +#endif // STBI_INCLUDE_STB_IMAGE_H #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ - defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) #ifndef STBI_ONLY_JPEG #define STBI_NO_JPEG @@ -585,12 +585,12 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha #include #include -#include // ptrdiff_t on osx +#include // ptrdiff_t on osx #include #include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -725,8 +725,8 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid static int stbi__cpuid3(void) { int info[4]; __cpuid(info, 1); @@ -753,7 +753,7 @@ static int stbi__sse2_available(void) { } #endif -#else // assume GCC-style if not VC++ +#else // assume GCC-style if not VC++ #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) @@ -801,7 +801,7 @@ typedef struct { int img_n, img_out_n; stbi_io_callbacks io; - void * io_user_data; + void *io_user_data; int read_from_callbacks; int buflen; @@ -812,10 +812,10 @@ typedef struct { stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; -static void stbi__refill_buffer(stbi__context * s); +static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context * s, stbi_uc const * buffer, int len) { +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) { s->io.read = NULL; s->read_from_callbacks = 0; s->callback_already_read = 0; @@ -824,7 +824,7 @@ static void stbi__start_mem(stbi__context * s, stbi_uc const * buffer, int len) } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context * s, stbi_io_callbacks * c, void * user) { +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) { s->io = *c; s->io_user_data = user; s->buflen = sizeof(s->buffer_start); @@ -837,9 +837,11 @@ static void stbi__start_callbacks(stbi__context * s, stbi_io_callbacks * c, void #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void * user, char * data, int size) { return (int)fread(data, 1, size, (FILE *)user); } +static int stbi__stdio_read(void *user, char *data, int size) { + return (int)fread(data, 1, size, (FILE *)user); +} -static void stbi__stdio_skip(void * user, int n) { +static void stbi__stdio_skip(void *user, int n) { int ch; fseek((FILE *)user, n, SEEK_CUR); ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ @@ -848,7 +850,9 @@ static void stbi__stdio_skip(void * user, int n) { } } -static int stbi__stdio_eof(void * user) { return feof((FILE *)user) || ferror((FILE *)user); } +static int stbi__stdio_eof(void *user) { + return feof((FILE *)user) || ferror((FILE *)user); +} static stbi_io_callbacks stbi__stdio_callbacks = { stbi__stdio_read, @@ -856,13 +860,15 @@ static stbi_io_callbacks stbi__stdio_callbacks = { stbi__stdio_eof, }; -static void stbi__start_file(stbi__context * s, FILE * f) { stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } +static void stbi__start_file(stbi__context *s, FILE *f) { + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); +} // static void stop_file(stbi__context *s) { } -#endif // !STBI_NO_STDIO +#endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context * s) { +static void stbi__rewind(stbi__context *s) { // conceptually rewind SHOULD rewind to the beginning of the stream, // but we just rewind to the beginning of the initial buffer, because // we only use it after doing 'test', which only ever looks at at most 92 bytes @@ -870,7 +876,8 @@ static void stbi__rewind(stbi__context * s) { s->img_buffer_end = s->img_buffer_original_end; } -enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; +enum { STBI_ORDER_RGB, + STBI_ORDER_BGR }; typedef struct { int bits_per_channel; @@ -879,79 +886,83 @@ typedef struct { } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context * s); -static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context * s); -static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__png_is16(stbi__context * s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context * s); -static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context * s); -static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context * s); -static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc); -static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__psd_is16(stbi__context * s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context * s); -static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context * s); -static void * stbi__pic_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context * s); -static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp); -static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context * s); -static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__pnm_is16(stbi__context * s); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_is16(stbi__context *s); #endif static #ifdef STBI_THREAD_LOCAL STBI_THREAD_LOCAL #endif - const char * stbi__g_failure_reason; + const char *stbi__g_failure_reason; -STBIDEF const char * stbi_failure_reason(void) { return stbi__g_failure_reason; } +STBIDEF const char *stbi_failure_reason(void) { + return stbi__g_failure_reason; +} #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char * str) { +static int stbi__err(const char *str) { stbi__g_failure_reason = str; return 0; } #endif -static void * stbi__malloc(size_t size) { return STBI_MALLOC(size); } +static void *stbi__malloc(size_t size) { + return STBI_MALLOC(size); +} // stb_image uses ints pervasively, including for offset calculations. // therefore the largest decoded image size we can support with the @@ -981,7 +992,7 @@ static int stbi__mul2sizes_valid(int a, int b) { if (a < 0 || b < 0) return 0; if (b == 0) - return 1; // mul-by-0 is always safe + return 1; // mul-by-0 is always safe // portable way to check for no overflows in a*b return a <= INT_MAX / b; } @@ -1008,21 +1019,21 @@ static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { #if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void * stbi__malloc_mad2(int a, int b, int add) { +static void *stbi__malloc_mad2(int a, int b, int add) { if (!stbi__mad2sizes_valid(a, b, add)) return NULL; return stbi__malloc(a * b + add); } #endif -static void * stbi__malloc_mad3(int a, int b, int c, int add) { +static void *stbi__malloc_mad3(int a, int b, int c, int add) { if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; return stbi__malloc(a * b * c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) -static void * stbi__malloc_mad4(int a, int b, int c, int d, int add) { +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) { if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; return stbi__malloc(a * b * c * d + add); @@ -1032,20 +1043,20 @@ static void * stbi__malloc_mad4(int a, int b, int c, int d, int add) { // returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 inclusive), 0 on overflow. static int stbi__addints_valid(int a, int b) { if ((a >= 0) != (b >= 0)) - return 1; // a and b have different signs, so no overflow + return 1; // a and b have different signs, so no overflow if (a < 0 && b < 0) - return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. + return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. return a <= INT_MAX - b; } // returns 1 if the product of two signed shorts is valid, 0 on overflow. static int stbi__mul2shorts_valid(short a, short b) { if (b == 0 || b == -1) - return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow + return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow if ((a >= 0) == (b >= 0)) - return a <= SHRT_MAX / b; // product is positive, so similar to mul2sizes_valid + return a <= SHRT_MAX / b; // product is positive, so similar to mul2sizes_valid if (b < 0) - return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN + return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN return a >= SHRT_MIN / b; } @@ -1064,14 +1075,16 @@ static int stbi__mul2shorts_valid(short a, short b) { #define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) #define stbi__errpuc(x, y) ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -STBIDEF void stbi_image_free(void * retval_from_stbi_load) { STBI_FREE(retval_from_stbi_load); } +STBIDEF void stbi_image_free(void *retval_from_stbi_load) { + STBI_FREE(retval_from_stbi_load); +} #ifndef STBI_NO_LINEAR -static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; @@ -1090,14 +1103,14 @@ STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_fli stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load \ +#define stbi__vertically_flip_on_load \ (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local : stbi__vertically_flip_on_load_global) -#endif // STBI_THREAD_LOCAL +#endif // STBI_THREAD_LOCAL -static void * stbi__load_main(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed + ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order ri->num_channels = 0; // test the formats with a very explicit header first (at least a FOURCC @@ -1139,7 +1152,7 @@ static void * stbi__load_main(stbi__context * s, int * x, int * y, int * comp, i #ifndef STBI_NO_HDR if (stbi__hdr_test(s)) { - float * hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + float *hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); } #endif @@ -1153,47 +1166,47 @@ static void * stbi__load_main(stbi__context * s, int * x, int * y, int * comp, i return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); } -static stbi_uc * stbi__convert_16_to_8(stbi__uint16 * orig, int w, int h, int channels) { +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) { int i; int img_len = w * h * channels; - stbi_uc * reduced; + stbi_uc *reduced; reduced = (stbi_uc *)stbi__malloc(img_len); if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling + reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling STBI_FREE(orig); return reduced; } -static stbi__uint16 * stbi__convert_8_to_16(stbi_uc * orig, int w, int h, int channels) { +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) { int i; int img_len = w * h * channels; - stbi__uint16 * enlarged; + stbi__uint16 *enlarged; enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); if (enlarged == NULL) return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff + enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff STBI_FREE(orig); return enlarged; } -static void stbi__vertical_flip(void * image, int w, int h, int bytes_per_pixel) { +static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) { int row; size_t bytes_per_row = (size_t)w * bytes_per_pixel; stbi_uc temp[2048]; - stbi_uc * bytes = (stbi_uc *)image; + stbi_uc *bytes = (stbi_uc *)image; for (row = 0; row < (h >> 1); row++) { - stbi_uc * row0 = bytes + row * bytes_per_row; - stbi_uc * row1 = bytes + (h - row - 1) * bytes_per_row; + stbi_uc *row0 = bytes + row * bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1) * bytes_per_row; // swap row0 with row1 size_t bytes_left = bytes_per_row; while (bytes_left) { @@ -1209,11 +1222,11 @@ static void stbi__vertical_flip(void * image, int w, int h, int bytes_per_pixel) } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void * image, int w, int h, int z, int bytes_per_pixel) { +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) { int slice; int slice_size = w * h * bytes_per_pixel; - stbi_uc * bytes = (stbi_uc *)image; + stbi_uc *bytes = (stbi_uc *)image; for (slice = 0; slice < z; ++slice) { stbi__vertical_flip(bytes, w, h, bytes_per_pixel); bytes += slice_size; @@ -1221,9 +1234,9 @@ static void stbi__vertical_flip_slices(void * image, int w, int h, int z, int by } #endif -static unsigned char * stbi__load_and_postprocess_8bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) { stbi__result_info ri; - void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); if (result == NULL) return NULL; @@ -1246,9 +1259,9 @@ static unsigned char * stbi__load_and_postprocess_8bit(stbi__context * s, int * return (unsigned char *)result; } -static stbi__uint16 * stbi__load_and_postprocess_16bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) { stbi__result_info ri; - void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); if (result == NULL) return NULL; @@ -1273,7 +1286,7 @@ static stbi__uint16 * stbi__load_and_postprocess_16bit(stbi__context * s, int * } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float * result, int * x, int * y, int * comp, int req_comp) { +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) { if (stbi__vertically_flip_on_load && result != NULL) { int channels = req_comp ? req_comp : *comp; stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); @@ -1284,21 +1297,21 @@ static void stbi__float_postprocess(float * result, int * x, int * y, int * comp #ifndef STBI_NO_STDIO #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char * str, - int cbmb, wchar_t * widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, + int cbmb, wchar_t *widestr, int cchwide); STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, - const wchar_t * widestr, int cchwide, char * str, int cbmb, - const char * defchar, int * used_default); + const wchar_t *widestr, int cchwide, char *str, int cbmb, + const char *defchar, int *used_default); #endif #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input) { +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t *input) { return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int)bufferlen, NULL, NULL); } #endif -static FILE * stbi__fopen(char const * filename, char const * mode) { - FILE * f; +static FILE *stbi__fopen(char const *filename, char const *mode) { + FILE *f; #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) wchar_t wMode[64]; wchar_t wFilename[1024]; @@ -1324,9 +1337,9 @@ static FILE * stbi__fopen(char const * filename, char const * mode) { return f; } -STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * comp, int req_comp) { - FILE * f = stbi__fopen(filename, "rb"); - unsigned char * result; +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); result = stbi_load_from_file(f, x, y, comp, req_comp); @@ -1334,8 +1347,8 @@ STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * comp, return result; } -STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { - unsigned char * result; +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) { + unsigned char *result; stbi__context s; stbi__start_file(&s, f); result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); @@ -1346,8 +1359,8 @@ STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * comp, in return result; } -STBIDEF stbi__uint16 * stbi_load_from_file_16(FILE * f, int * x, int * y, int * comp, int req_comp) { - stbi__uint16 * result; +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) { + stbi__uint16 *result; stbi__context s; stbi__start_file(&s, f); result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); @@ -1358,9 +1371,9 @@ STBIDEF stbi__uint16 * stbi_load_from_file_16(FILE * f, int * x, int * y, int * return result; } -STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * comp, int req_comp) { - FILE * f = stbi__fopen(filename, "rb"); - stbi__uint16 * result; +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; if (!f) return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); result = stbi_load_from_file_16(f, x, y, comp, req_comp); @@ -1368,39 +1381,39 @@ STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * co return result; } -#endif //! STBI_NO_STDIO +#endif //! STBI_NO_STDIO -STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels) { +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, + int desired_channels) { stbi__context s; stbi__start_mem(&s, buffer, len); return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); } -STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels) { +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, + int *channels_in_file, int desired_channels) { stbi__context s; stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); } -STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) { stbi__context s; stbi__start_mem(&s, buffer, len); return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } -STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, - int req_comp) { +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, + int req_comp) { stbi__context s; stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, - int * comp, int req_comp) { - unsigned char * result; +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, + int *comp, int req_comp) { + unsigned char *result; stbi__context s; stbi__start_mem(&s, buffer, len); @@ -1414,12 +1427,12 @@ STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int #endif #ifndef STBI_NO_LINEAR -static float * stbi__loadf_main(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - unsigned char * data; +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) { + unsigned char *data; #ifndef STBI_NO_HDR if (stbi__hdr_test(s)) { stbi__result_info ri; - float * hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + float *hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); if (hdr_data) stbi__float_postprocess(hdr_data, x, y, comp, req_comp); return hdr_data; @@ -1431,23 +1444,23 @@ static float * stbi__loadf_main(stbi__context * s, int * x, int * y, int * comp, return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); } -STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) { stbi__context s; stbi__start_mem(&s, buffer, len); return stbi__loadf_main(&s, x, y, comp, req_comp); } -STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, - int req_comp) { +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, + int req_comp) { stbi__context s; stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); return stbi__loadf_main(&s, x, y, comp, req_comp); } #ifndef STBI_NO_STDIO -STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * comp, int req_comp) { - float * result; - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) { + float *result; + FILE *f = stbi__fopen(filename, "rb"); if (!f) return stbi__errpf("can't fopen", "Unable to open file"); result = stbi_loadf_from_file(f, x, y, comp, req_comp); @@ -1455,20 +1468,20 @@ STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * comp, return result; } -STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) { stbi__context s; stbi__start_file(&s, f); return stbi__loadf_main(&s, x, y, comp, req_comp); } -#endif // !STBI_NO_STDIO +#endif // !STBI_NO_STDIO -#endif // !STBI_NO_LINEAR +#endif // !STBI_NO_LINEAR // these is-hdr-or-not is defined independent of whether STBI_NO_LINEAR is // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len) { +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) { #ifndef STBI_NO_HDR stbi__context s; stbi__start_mem(&s, buffer, len); @@ -1481,8 +1494,8 @@ STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len) { } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr(char const * filename) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_is_hdr(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); int result = 0; if (f) { result = stbi_is_hdr_from_file(f); @@ -1491,7 +1504,7 @@ STBIDEF int stbi_is_hdr(char const * filename) { return result; } -STBIDEF int stbi_is_hdr_from_file(FILE * f) { +STBIDEF int stbi_is_hdr_from_file(FILE *f) { #ifndef STBI_NO_HDR long pos = ftell(f); int res; @@ -1505,9 +1518,9 @@ STBIDEF int stbi_is_hdr_from_file(FILE * f) { return 0; #endif } -#endif // !STBI_NO_STDIO +#endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user) { +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) { #ifndef STBI_NO_HDR stbi__context s; stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); @@ -1522,23 +1535,33 @@ STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * us #ifndef STBI_NO_LINEAR static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { + stbi__l2h_gamma = gamma; +} +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { + stbi__l2h_scale = scale; +} #endif static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1 / gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1 / scale; } +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { + stbi__h2l_gamma_i = 1 / gamma; +} +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { + stbi__h2l_scale_i = 1 / scale; +} ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; +enum { STBI__SCAN_load = 0, + STBI__SCAN_type, + STBI__SCAN_header }; -static void stbi__refill_buffer(stbi__context * s) { +static void stbi__refill_buffer(stbi__context *s) { int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); if (n == 0) { @@ -1554,7 +1577,7 @@ static void stbi__refill_buffer(stbi__context * s) { } } -stbi_inline static stbi_uc stbi__get8(stbi__context * s) { +stbi_inline static stbi_uc stbi__get8(stbi__context *s) { if (s->img_buffer < s->img_buffer_end) return *s->img_buffer++; if (s->read_from_callbacks) { @@ -1567,7 +1590,7 @@ stbi_inline static stbi_uc stbi__get8(stbi__context * s) { #if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context * s) { +stbi_inline static int stbi__at_eof(stbi__context *s) { if (s->io.read) { if (!(s->io.eof)(s->io_user_data)) return 0; @@ -1581,13 +1604,13 @@ stbi_inline static int stbi__at_eof(stbi__context * s) { } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ defined(STBI_NO_GIF) && defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context * s, int n) { +static void stbi__skip(stbi__context *s, int n) { if (n == 0) - return; // already there! + return; // already there! if (n < 0) { s->img_buffer = s->img_buffer_end; return; @@ -1607,7 +1630,7 @@ static void stbi__skip(stbi__context * s, int n) { #if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context * s, stbi_uc * buffer, int n) { +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) { if (s->io.read) { int blen = (int)(s->img_buffer_end - s->img_buffer); if (blen < n) { @@ -1634,7 +1657,7 @@ static int stbi__getn(stbi__context * s, stbi_uc * buffer, int n) { #if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context * s) { +static int stbi__get16be(stbi__context *s) { int z = stbi__get8(s); return (z << 8) + stbi__get8(s); } @@ -1643,7 +1666,7 @@ static int stbi__get16be(stbi__context * s) { #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context * s) { +static stbi__uint32 stbi__get32be(stbi__context *s) { stbi__uint32 z = stbi__get16be(s); return (z << 16) + stbi__get16be(s); } @@ -1652,23 +1675,23 @@ static stbi__uint32 stbi__get32be(stbi__context * s) { #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context * s) { +static int stbi__get16le(stbi__context *s) { int z = stbi__get8(s); return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context * s) { +static stbi__uint32 stbi__get32le(stbi__context *s) { stbi__uint32 z = stbi__get16le(s); z += (stbi__uint32)stbi__get16le(s) << 16; return z; } #endif -#define STBI__BYTECAST(x) ((stbi_uc)((x)&255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) ((stbi_uc)((x)&255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else @@ -1683,16 +1706,18 @@ static stbi__uint32 stbi__get32le(stbi__context * s) { // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) { return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +static stbi_uc stbi__compute_y(int r, int g, int b) { + return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); +} #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else -static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int req_comp, unsigned int x, unsigned int y) { +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) { int i, j; - unsigned char * good; + unsigned char *good; if (req_comp == img_n) return data; @@ -1705,12 +1730,12 @@ static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int } for (j = 0; j < (int)y; ++j) { - unsigned char * src = data + j * x * img_n; - unsigned char * dest = good + j * x * req_comp; + unsigned char *src = data + j * x * img_n; + unsigned char *dest = good + j * x * req_comp; #define STBI__COMBO(a, b) ((a)*8 + (b)) -#define STBI__CASE(a, b) \ - case STBI__COMBO(a, b): \ +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ for (i = x - 1; i >= 0; --i, src += a, dest += b) // convert source image with img_n components to one with req_comp components; // avoid switch per pixel, so use switch per scanline and massive macros @@ -1720,16 +1745,22 @@ static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int dest[1] = 255; } break; - STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } break; STBI__CASE(1, 4) { dest[0] = dest[1] = dest[2] = src[0]; dest[3] = 255; } break; - STBI__CASE(2, 1) { dest[0] = src[0]; } + STBI__CASE(2, 1) { + dest[0] = src[0]; + } break; - STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } break; STBI__CASE(2, 4) { dest[0] = dest[1] = dest[2] = src[0]; @@ -1743,14 +1774,18 @@ static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int dest[3] = 255; } break; - STBI__CASE(3, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } break; STBI__CASE(3, 2) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); dest[1] = 255; } break; - STBI__CASE(4, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } break; STBI__CASE(4, 2) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); @@ -1780,15 +1815,17 @@ static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { + return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); +} #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int req_comp, unsigned int x, unsigned int y) { +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) { int i, j; - stbi__uint16 * good; + stbi__uint16 *good; if (req_comp == img_n) return data; @@ -1801,12 +1838,12 @@ static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int } for (j = 0; j < (int)y; ++j) { - stbi__uint16 * src = data + j * x * img_n; - stbi__uint16 * dest = good + j * x * req_comp; + stbi__uint16 *src = data + j * x * img_n; + stbi__uint16 *dest = good + j * x * req_comp; #define STBI__COMBO(a, b) ((a)*8 + (b)) -#define STBI__CASE(a, b) \ - case STBI__COMBO(a, b): \ +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ for (i = x - 1; i >= 0; --i, src += a, dest += b) // convert source image with img_n components to one with req_comp components; // avoid switch per pixel, so use switch per scanline and massive macros @@ -1816,16 +1853,22 @@ static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int dest[1] = 0xffff; } break; - STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } break; STBI__CASE(1, 4) { dest[0] = dest[1] = dest[2] = src[0]; dest[3] = 0xffff; } break; - STBI__CASE(2, 1) { dest[0] = src[0]; } + STBI__CASE(2, 1) { + dest[0] = src[0]; + } break; - STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } break; STBI__CASE(2, 4) { dest[0] = dest[1] = dest[2] = src[0]; @@ -1839,14 +1882,18 @@ static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int dest[3] = 0xffff; } break; - STBI__CASE(3, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } break; STBI__CASE(3, 2) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); dest[1] = 0xffff; } break; - STBI__CASE(4, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } break; STBI__CASE(4, 2) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); @@ -1874,9 +1921,9 @@ static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int #endif #ifndef STBI_NO_LINEAR -static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp) { +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) { int i, k, n; - float * output; + float *output; if (!data) return NULL; output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); @@ -1906,9 +1953,9 @@ static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp) { #ifndef STBI_NO_HDR #define stbi__float2int(x) ((int)(x)) -static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) { int i, k, n; - stbi_uc * output; + stbi_uc *output; if (!data) return NULL; output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); @@ -1968,7 +2015,7 @@ static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache typedef struct { stbi_uc fast[1 << FAST_BITS]; @@ -1977,11 +2024,11 @@ typedef struct { stbi_uc values[256]; stbi_uc size[257]; unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; typedef struct { - stbi__context * s; + stbi__context *s; stbi__huffman huff_dc[4]; stbi__huffman huff_ac[4]; stbi__uint16 dequant[4][64]; @@ -2001,17 +2048,17 @@ typedef struct { int dc_pred; int x, y, w2, h2; - stbi_uc * data; + stbi_uc *data; void *raw_data, *raw_coeff; - stbi_uc * linebuf; - short * coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks } img_comp[4]; - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop int progressive; int spec_start; @@ -2020,20 +2067,20 @@ typedef struct { int succ_low; int eob_run; int jfif; - int app14_color_transform; // Adobe APP14 tag + int app14_color_transform; // Adobe APP14 tag int rgb; int scan_n, order[4]; int restart_interval, todo; // kernels - void (*idct_block_kernel)(stbi_uc * out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); - stbi_uc * (*resample_row_hv_2_kernel)(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman * h, int * count) { +static int stbi__build_huffman(stbi__huffman *h, int *count) { int i, j, k = 0; unsigned int code; // build size list for each symbol (from JPEG spec) @@ -2081,7 +2128,7 @@ static int stbi__build_huffman(stbi__huffman * h, int * count) { // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 * fast_ac, stbi__huffman * h) { +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) { int i; for (i = 0; i < (1 << FAST_BITS); ++i) { stbi_uc fast = h->fast[i]; @@ -2106,13 +2153,13 @@ static void stbi__build_fast_ac(stbi__int16 * fast_ac, stbi__huffman * h) { } } -static void stbi__grow_buffer_unsafe(stbi__jpeg * j) { +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) { do { unsigned int b = j->nomore ? 0 : stbi__get8(j->s); if (b == 0xff) { int c = stbi__get8(j->s); while (c == 0xff) - c = stbi__get8(j->s); // consume fill bytes + c = stbi__get8(j->s); // consume fill bytes if (c != 0) { j->marker = (unsigned char)c; j->nomore = 1; @@ -2125,11 +2172,11 @@ static void stbi__grow_buffer_unsafe(stbi__jpeg * j) { } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17] = {0, 1, 3, 7, 15, 31, 63, 127, 255, +static const stbi__uint32 stbi__bmask[17] = {0, 1, 3, 7, 15, 31, 63, 127, 255, 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg * j, stbi__huffman * h) { +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) { unsigned int temp; int c, k; @@ -2170,7 +2217,7 @@ stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg * j, stbi__huffman * h) // convert the huffman code to the symbol id c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - if (c < 0 || c >= 256) // symbol id out of bounds! + if (c < 0 || c >= 256) // symbol id out of bounds! return -1; STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); @@ -2185,15 +2232,15 @@ static const int stbi__jbias[16] = {0, -1, -3, -7, -15, -31, -63, -127, -255, -5 // combined JPEG 'receive' and JPEG 'extend', since baseline // always extends everything it receives. -stbi_inline static int stbi__extend_receive(stbi__jpeg * j, int n) { +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) { unsigned int k; int sgn; if (j->code_bits < n) stbi__grow_buffer_unsafe(j); if (j->code_bits < n) - return 0; // ran out of bits from stream, return 0s intead of continuing + return 0; // ran out of bits from stream, return 0s intead of continuing - sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) + sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) k = stbi_lrot(j->code_buffer, n); j->code_buffer = k & ~stbi__bmask[n]; k &= stbi__bmask[n]; @@ -2202,12 +2249,12 @@ stbi_inline static int stbi__extend_receive(stbi__jpeg * j, int n) { } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg * j, int n) { +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) { unsigned int k; if (j->code_bits < n) stbi__grow_buffer_unsafe(j); if (j->code_bits < n) - return 0; // ran out of bits from stream, return 0s intead of continuing + return 0; // ran out of bits from stream, return 0s intead of continuing k = stbi_lrot(j->code_buffer, n); j->code_buffer = k & ~stbi__bmask[n]; k &= stbi__bmask[n]; @@ -2215,12 +2262,12 @@ stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg * j, int n) { return k; } -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg * j) { +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) { unsigned int k; if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); if (j->code_bits < 1) - return 0; // ran out of bits from stream, return 0s intead of continuing + return 0; // ran out of bits from stream, return 0s intead of continuing k = j->code_buffer; j->code_buffer <<= 1; --j->code_bits; @@ -2236,8 +2283,8 @@ static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman * hdc, stbi__huffman * hac, stbi__int16 * fac, - int b, stbi__uint16 * dequant) { +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, + int b, stbi__uint16 *dequant) { int diff, dc, k; int t; @@ -2268,9 +2315,9 @@ static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman stbi__grow_buffer_unsafe(j); c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); j->code_buffer <<= s; @@ -2286,7 +2333,7 @@ static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman r = rs >> 4; if (s == 0) { if (rs != 0xf0) - break; // end block + break; // end block k += 16; } else { k += r; @@ -2299,7 +2346,7 @@ static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman return 1; } -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg * j, short data[64], stbi__huffman * hdc, int b) { +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) { int diff, dc; int t; if (j->spec_end != 0) @@ -2310,7 +2357,7 @@ static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg * j, short data[64], stbi_ if (j->succ_high == 0) { // first scan for DC coefficient, must be first - memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now t = stbi__jpeg_huff_decode(j, hdc); if (t < 0 || t > 15) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); @@ -2333,7 +2380,7 @@ static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg * j, short data[64], stbi_ // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi__huffman * hac, stbi__int16 * fac) { +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) { int k; if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); @@ -2354,9 +2401,9 @@ static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi_ stbi__grow_buffer_unsafe(j); c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); j->code_buffer <<= s; @@ -2393,7 +2440,7 @@ static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi_ if (j->eob_run) { --j->eob_run; for (k = j->spec_start; k <= j->spec_end; ++k) { - short * p = &data[stbi__jpeg_dezigzag[k]]; + short *p = &data[stbi__jpeg_dezigzag[k]]; if (*p != 0) if (stbi__jpeg_get_bit(j)) if ((*p & bit) == 0) { @@ -2408,7 +2455,7 @@ static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi_ do { int r, s; int rs = stbi__jpeg_huff_decode( - j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh if (rs < 0) return stbi__err("bad huffman code", "Corrupt JPEG"); s = rs & 15; @@ -2418,7 +2465,7 @@ static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi_ j->eob_run = (1 << r) - 1; if (r) j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block + r = 64; // force end of block } else { // r=15 s=0 should write 16 0s, so we just do // a run of 15 0s and then write s (which is 0), @@ -2436,7 +2483,7 @@ static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi_ // advance by r while (k <= j->spec_end) { - short * p = &data[stbi__jpeg_dezigzag[k++]]; + short *p = &data[stbi__jpeg_dezigzag[k++]]; if (*p != 0) { if (stbi__jpeg_get_bit(j)) if ((*p & bit) == 0) { @@ -2475,47 +2522,47 @@ stbi_inline static stbi_uc stbi__clamp(int x) { #define stbi__fsh(x) ((x)*4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ - int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ - t3 = p1 + p2 * stbi__f2f(0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2 + p3); \ - t1 = stbi__fsh(p2 - p3); \ - x0 = t0 + t3; \ - x3 = t0 - t3; \ - x1 = t1 + t2; \ - x2 = t1 - t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0 + t2; \ - p4 = t1 + t3; \ - p1 = t0 + t3; \ - p2 = t1 + t2; \ - p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ - t0 = t0 * stbi__f2f(0.298631336f); \ - t1 = t1 * stbi__f2f(2.053119869f); \ - t2 = t2 * stbi__f2f(3.072711026f); \ - t3 = t3 * stbi__f2f(1.501321110f); \ - p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ - p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ - p3 = p3 * stbi__f2f(-1.961570560f); \ - p4 = p4 * stbi__f2f(-0.390180644f); \ - t3 += p1 + p4; \ - t2 += p2 + p3; \ - t1 += p2 + p4; \ +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ t0 += p1 + p3; -static void stbi__idct_block(stbi_uc * out, int out_stride, short data[64]) { +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) { int i, val[64], *v = val; - stbi_uc * o; - short * d = data; + stbi_uc *o; + short *d = data; // columns for (i = 0; i < 8; ++i, ++d, ++v) { @@ -2576,7 +2623,7 @@ static void stbi__idct_block(stbi_uc * out, int out_stride, short data[64]) { // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { // This is constructed to match our regular (generic) integer IDCT exactly. __m128i row0, row1, row2, row3, row4, row5, row6, row7; __m128i tmp; @@ -2586,78 +2633,78 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) // out(1) = c1[even]*x + c1[odd]*y -#define dct_rot(out0, out1, x, y, c0, c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) // out = in << 12 (in 16-bit, out 32-bit) -#define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ +#define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) // wide add -#define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ __m128i out##_h = _mm_add_epi32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) // butterfly a/b, add bias, then shift by "s" and pack -#define dct_bfly32o(out0, out1, a, b, bias, s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ } // 8-bit interleave step (for transposes) -#define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ b = _mm_unpackhi_epi8(tmp, b) // 16-bit interleave step (for transposes) -#define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ b = _mm_unpackhi_epi16(tmp, b) -#define dct_pass(bias, shift) \ - { \ - /* even part */ \ - dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ - dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0, row7, x0, x7, bias, shift); \ - dct_bfly32o(row1, row6, x1, x6, bias, shift); \ - dct_bfly32o(row2, row5, x2, x5, bias, shift); \ - dct_bfly32o(row3, row4, x3, x4, bias, shift); \ +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ } __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); @@ -2711,22 +2758,22 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { { // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 __m128i p1 = _mm_packus_epi16(row2, row3); __m128i p2 = _mm_packus_epi16(row4, row5); __m128i p3 = _mm_packus_epi16(row6, row7); // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... // store _mm_storel_epi64((__m128i *)out, p0); @@ -2757,13 +2804,13 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { #undef dct_pass } -#endif // STBI_SSE2 +#endif // STBI_SSE2 #ifdef STBI_NEON // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); @@ -2779,75 +2826,75 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ - dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ - dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ - dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ } // load @@ -2870,40 +2917,40 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) \ - { \ - int16x8x2_t t = vtrnq_s16(x, y); \ - x = t.val[0]; \ - y = t.val[1]; \ - } -#define dct_trn32(x, y) \ - { \ - int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ - x = vreinterpretq_s16_s32(t.val[0]); \ - y = vreinterpretq_s16_s32(t.val[1]); \ - } -#define dct_trn64(x, y) \ - { \ - int16x8_t x0 = x; \ - int16x8_t y0 = y; \ - x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ - y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ } // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 dct_trn16(row2, row3); dct_trn16(row4, row5); dct_trn16(row6, row7); // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 dct_trn32(row1, row3); dct_trn32(row4, row6); dct_trn32(row5, row7); // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 dct_trn64(row1, row5); dct_trn64(row2, row6); dct_trn64(row3, row7); @@ -2931,23 +2978,23 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { uint8x8_t p7 = vqrshrun_n_s16(row7, 1); // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) \ - { \ - uint8x8x2_t t = vtrn_u8(x, y); \ - x = t.val[0]; \ - y = t.val[1]; \ +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ } -#define dct_trn8_16(x, y) \ - { \ - uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ - x = vreinterpret_u8_u16(t.val[0]); \ - y = vreinterpret_u8_u16(t.val[1]); \ +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ } -#define dct_trn8_32(x, y) \ - { \ - uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ - x = vreinterpret_u8_u32(t.val[0]); \ - y = vreinterpret_u8_u32(t.val[1]); \ +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ } // sadly can't use interleaved stores here since we only write @@ -3002,13 +3049,13 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { #undef dct_pass } -#endif // STBI_NEON +#endif // STBI_NEON #define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg * j) { +static stbi_uc stbi__get_marker(stbi__jpeg *j) { stbi_uc x; if (j->marker != STBI__MARKER_none) { x = j->marker; @@ -3019,7 +3066,7 @@ static stbi_uc stbi__get_marker(stbi__jpeg * j) { if (x != 0xff) return STBI__MARKER_none; while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes + x = stbi__get8(j->s); // consume repeated 0xff fill bytes return x; } @@ -3029,7 +3076,7 @@ static stbi_uc stbi__get_marker(stbi__jpeg * j) { // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg * j) { +static void stbi__jpeg_reset(stbi__jpeg *j) { j->code_bits = 0; j->code_buffer = 0; j->nomore = 0; @@ -3041,7 +3088,7 @@ static void stbi__jpeg_reset(stbi__jpeg * j) { // since we don't even allow 1<<30 pixels } -static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) { stbi__jpeg_reset(z); if (!z->progressive) { if (z->scan_n == 1) { @@ -3074,7 +3121,7 @@ static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { } } return 1; - } else { // interleaved + } else { // interleaved int i, j, k, x, y; STBI_SIMD_ALIGN(short, data[64]); for (j = 0; j < z->img_mcu_y; ++j) { @@ -3122,7 +3169,7 @@ static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { int h = (z->img_comp[n].y + 7) >> 3; for (j = 0; j < h; ++j) { for (i = 0; i < w; ++i) { - short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); if (z->spec_start == 0) { if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) return 0; @@ -3142,7 +3189,7 @@ static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { } } return 1; - } else { // interleaved + } else { // interleaved int i, j, k, x, y; for (j = 0; j < z->img_mcu_y; ++j) { for (i = 0; i < z->img_mcu_x; ++i) { @@ -3155,7 +3202,7 @@ static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { for (x = 0; x < z->img_comp[n].h; ++x) { int x2 = (i * z->img_comp[n].h + x); int y2 = (j * z->img_comp[n].v + y); - short * data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) return 0; } @@ -3177,13 +3224,13 @@ static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { } } -static void stbi__jpeg_dequantize(short * data, stbi__uint16 * dequant) { +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) { int i; for (i = 0; i < 64; ++i) data[i] *= dequant[i]; } -static void stbi__jpeg_finish(stbi__jpeg * z) { +static void stbi__jpeg_finish(stbi__jpeg *z) { if (z->progressive) { // dequantize and idct the data int i, j, n; @@ -3192,7 +3239,7 @@ static void stbi__jpeg_finish(stbi__jpeg * z) { int h = (z->img_comp[n].y + 7) >> 3; for (j = 0; j < h; ++j) { for (i = 0; i < w; ++i) { - short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); } @@ -3201,19 +3248,19 @@ static void stbi__jpeg_finish(stbi__jpeg * z) { } } -static int stbi__process_marker(stbi__jpeg * z, int m) { +static int stbi__process_marker(stbi__jpeg *z, int m) { int L; switch (m) { - case STBI__MARKER_none: // no marker found + case STBI__MARKER_none: // no marker found return stbi__err("expected marker", "Corrupt JPEG"); - case 0xDD: // DRI - specify restart interval + case 0xDD: // DRI - specify restart interval if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len", "Corrupt JPEG"); z->restart_interval = stbi__get16be(z->s); return 1; - case 0xDB: // DQT - define quantization table + case 0xDB: // DQT - define quantization table L = stbi__get16be(z->s) - 2; while (L > 0) { int q = stbi__get8(z->s); @@ -3230,10 +3277,10 @@ static int stbi__process_marker(stbi__jpeg * z, int m) { } return L == 0; - case 0xC4: // DHT - define huffman table + case 0xC4: // DHT - define huffman table L = stbi__get16be(z->s) - 2; while (L > 0) { - stbi_uc * v; + stbi_uc *v; int sizes[16], i, n = 0; int q = stbi__get8(z->s); int tc = q >> 4; @@ -3245,7 +3292,7 @@ static int stbi__process_marker(stbi__jpeg * z, int m) { n += sizes[i]; } if (n > 256) - return stbi__err("bad DHT header", "Corrupt JPEG"); // Loop over i < n would write past end of values! + return stbi__err("bad DHT header", "Corrupt JPEG"); // Loop over i < n would write past end of values! L -= 17; if (tc == 0) { if (!stbi__build_huffman(z->huff_dc + th, sizes)) @@ -3276,7 +3323,7 @@ static int stbi__process_marker(stbi__jpeg * z, int m) { } L -= 2; - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; int ok = 1; int i; @@ -3286,7 +3333,7 @@ static int stbi__process_marker(stbi__jpeg * z, int m) { L -= 5; if (ok) z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; int ok = 1; int i; @@ -3295,10 +3342,10 @@ static int stbi__process_marker(stbi__jpeg * z, int m) { ok = 0; L -= 6; if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform L -= 6; } } @@ -3311,7 +3358,7 @@ static int stbi__process_marker(stbi__jpeg * z, int m) { } // after we see SOS -static int stbi__process_scan_header(stbi__jpeg * z) { +static int stbi__process_scan_header(stbi__jpeg *z) { int i; int Ls = stbi__get16be(z->s); z->scan_n = stbi__get8(z->s); @@ -3326,7 +3373,7 @@ static int stbi__process_scan_header(stbi__jpeg * z) { if (z->img_comp[which].id == id) break; if (which == z->s->img_n) - return 0; // no match + return 0; // no match z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff", "Corrupt JPEG"); @@ -3339,7 +3386,7 @@ static int stbi__process_scan_header(stbi__jpeg * z) { { int aa; z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 aa = stbi__get8(z->s); z->succ_high = (aa >> 4); z->succ_low = (aa & 15); @@ -3358,7 +3405,7 @@ static int stbi__process_scan_header(stbi__jpeg * z) { return 1; } -static int stbi__free_jpeg_components(stbi__jpeg * z, int ncomp, int why) { +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) { int i; for (i = 0; i < ncomp; ++i) { if (z->img_comp[i].raw_data) { @@ -3379,22 +3426,22 @@ static int stbi__free_jpeg_components(stbi__jpeg * z, int ncomp, int why) { return why; } -static int stbi__process_frame_header(stbi__jpeg * z, int scan) { - stbi__context * s = z->s; +static int stbi__process_frame_header(stbi__jpeg *z, int scan) { + stbi__context *s = z->s; int Lf, p, i, q, h_max = 1, v_max = 1, c; Lf = stbi__get16be(s); if (Lf < 11) - return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG p = stbi__get8(s); if (p != 8) - return stbi__err("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline + return stbi__err("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", - "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG s->img_x = stbi__get16be(s); if (s->img_x == 0) - return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large", "Very large image (corrupt?)"); if (s->img_x > STBI_MAX_DIMENSIONS) @@ -3504,11 +3551,11 @@ static int stbi__process_frame_header(stbi__jpeg * z, int scan) { #define stbi__SOF_progressive(x) ((x) == 0xc2) -static int stbi__decode_jpeg_header(stbi__jpeg * z, int scan) { +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) { int m; z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty m = stbi__get_marker(z); if (!stbi__SOI(m)) return stbi__err("no SOI", "Corrupt JPEG"); @@ -3532,12 +3579,12 @@ static int stbi__decode_jpeg_header(stbi__jpeg * z, int scan) { return 1; } -static int stbi__skip_jpeg_junk_at_end(stbi__jpeg * j) { +static int stbi__skip_jpeg_junk_at_end(stbi__jpeg *j) { // some JPEGs have junk at end, skip over it but if we find what looks // like a valid marker, resume there while (!stbi__at_eof(j->s)) { int x = stbi__get8(j->s); - while (x == 255) { // might be a marker + while (x == 255) { // might be a marker if (stbi__at_eof(j->s)) return STBI__MARKER_none; x = stbi__get8(j->s); @@ -3555,7 +3602,7 @@ static int stbi__skip_jpeg_junk_at_end(stbi__jpeg * j) { } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg * j) { +static int stbi__decode_jpeg_image(stbi__jpeg *j) { int m; for (m = 0; m < 4; m++) { j->img_comp[m].raw_data = NULL; @@ -3599,11 +3646,11 @@ static int stbi__decode_jpeg_image(stbi__jpeg * j) { // static jfif-centered resampling (across block boundaries) -typedef stbi_uc * (*resample_row_func)(stbi_uc * out, stbi_uc * in0, stbi_uc * in1, int w, int hs); +typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, int w, int hs); #define stbi__div4(x) ((stbi_uc)((x) >> 2)) -static stbi_uc * resample_row_1(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) { STBI_NOTUSED(out); STBI_NOTUSED(in_far); STBI_NOTUSED(w); @@ -3611,7 +3658,7 @@ static stbi_uc * resample_row_1(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_f return in_near; } -static stbi_uc * stbi__resample_row_v_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { +static stbi_uc *stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) { // need to generate two samples vertically for every one in input int i; STBI_NOTUSED(hs); @@ -3620,10 +3667,10 @@ static stbi_uc * stbi__resample_row_v_2(stbi_uc * out, stbi_uc * in_near, stbi_u return out; } -static stbi_uc * stbi__resample_row_h_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { +static stbi_uc *stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) { // need to generate two samples horizontally for every one in input int i; - stbi_uc * input = in_near; + stbi_uc *input = in_near; if (w == 1) { // if only one sample, can't do any interpolation @@ -3649,7 +3696,7 @@ static stbi_uc * stbi__resample_row_h_2(stbi_uc * out, stbi_uc * in_near, stbi_u #define stbi__div16(x) ((stbi_uc)((x) >> 4)) -static stbi_uc * stbi__resample_row_hv_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) { // need to generate 2x2 samples for every one in input int i, t0, t1; if (w == 1) { @@ -3673,7 +3720,7 @@ static stbi_uc * stbi__resample_row_hv_2(stbi_uc * out, stbi_uc * in_near, stbi_ } #if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) { // need to generate 2x2 samples for every one in input int i = 0, t0, t1; @@ -3697,7 +3744,7 @@ static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, __m128i nearw = _mm_unpacklo_epi8(nearb, zero); __m128i diff = _mm_sub_epi16(farw, nearw); __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row + __m128i curr = _mm_add_epi16(nears, diff); // current row // horizontal filter works the same based on shifted vers of current // row. "prev" is current row shifted right by 1 pixel; we need to @@ -3737,7 +3784,7 @@ static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, uint8x8_t nearb = vld1_u8(in_near + i); int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row + int16x8_t curr = vaddq_s16(nears, diff); // current row // horizontal filter works the same based on shifted vers of current // row. "prev" is current row shifted right by 1 pixel; we need to @@ -3788,7 +3835,7 @@ static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, } #endif -static stbi_uc * stbi__resample_row_generic(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) { // resample with nearest-neighbor int i, j; STBI_NOTUSED(in_far); @@ -3801,11 +3848,11 @@ static stbi_uc * stbi__resample_row_generic(stbi_uc * out, stbi_uc * in_near, st // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar #define stbi__float2fixed(x) (((int)((x)*4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) { int i; for (i = 0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int y_fixed = (y[i] << 20) + (1 << 19); // rounding int r, g, b; int cr = pcr[i] - 128; int cb = pcb[i] - 128; @@ -3842,7 +3889,7 @@ static void stbi__YCbCr_to_RGB_row(stbi_uc * out, const stbi_uc * y, const stbi_ } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc const * pcb, stbi_uc const * pcr, int count, +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) { int i = 0; @@ -3858,15 +3905,15 @@ static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc co __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); - __m128i xw = _mm_set1_epi16(255); // alpha channel + __m128i xw = _mm_set1_epi16(255); // alpha channel for (; i + 7 < count; i += 8) { // load __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 // unpack to short (and left-shift cr, cb by 8) __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); @@ -3954,7 +4001,7 @@ static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc co #endif for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int y_fixed = (y[i] << 20) + (1 << 19); // rounding int r, g, b; int cr = pcr[i] - 128; int cb = pcb[i] - 128; @@ -3992,7 +4039,7 @@ static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc co #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg * j) { +static void stbi__setup_jpeg(stbi__jpeg *j) { j->idct_block_kernel = stbi__idct_block; j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; @@ -4013,15 +4060,17 @@ static void stbi__setup_jpeg(stbi__jpeg * j) { } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg * j) { stbi__free_jpeg_components(j, j->s->img_n, 0); } +static void stbi__cleanup_jpeg(stbi__jpeg *j) { + stbi__free_jpeg_components(j, j->s->img_n, 0); +} typedef struct { resample_row_func resample; stbi_uc *line0, *line1; - int hs, vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication @@ -4030,9 +4079,9 @@ static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { return (stbi_uc)((t + (t >> 8)) >> 8); } -static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * comp, int req_comp) { +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) { int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe + z->s->img_n = 0; // make stbi__cleanup_jpeg safe // validate req_comp if (req_comp < 0 || req_comp > 4) @@ -4045,7 +4094,8 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * } // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : + 1; is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); @@ -4065,13 +4115,13 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * { int k; unsigned int i, j; - stbi_uc * output; - stbi_uc * coutput[4] = {NULL, NULL, NULL, NULL}; + stbi_uc *output; + stbi_uc *coutput[4] = {NULL, NULL, NULL, NULL}; stbi__resample res_comp[4]; for (k = 0; k < decode_n; ++k) { - stbi__resample * r = &res_comp[k]; + stbi__resample *r = &res_comp[k]; // allocate line buffer big enough for upsampling off the edges // with upsample factor of 4 @@ -4109,9 +4159,9 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * // now go ahead and resample for (j = 0; j < z->s->img_y; ++j) { - stbi_uc * out = output + n * z->s->img_x * j; + stbi_uc *out = output + n * z->s->img_x * j; for (k = 0; k < decode_n; ++k) { - stbi__resample * r = &res_comp[k]; + stbi__resample *r = &res_comp[k]; int y_bot = r->ystep >= (r->vs >> 1); coutput[k] = r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, y_bot ? r->line0 : r->line1, r->w_lores, r->hs); @@ -4123,7 +4173,7 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * } } if (n >= 3) { - stbi_uc * y = coutput[0]; + stbi_uc *y = coutput[0]; if (z->s->img_n == 3) { if (is_rgb) { for (i = 0; i < z->s->img_x; ++i) { @@ -4137,7 +4187,7 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); } } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK + if (z->app14_color_transform == 0) { // CMYK for (i = 0; i < z->s->img_x; ++i) { stbi_uc m = coutput[3][i]; out[0] = stbi__blinn_8x8(coutput[0][i], m); @@ -4146,7 +4196,7 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * out[3] = 255; out += n; } - } else if (z->app14_color_transform == 2) { // YCCK + } else if (z->app14_color_transform == 2) { // YCCK z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); for (i = 0; i < z->s->img_x; ++i) { stbi_uc m = coutput[3][i]; @@ -4155,13 +4205,13 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * out[2] = stbi__blinn_8x8(255 - out[2], m); out += n; } - } else { // YCbCr + alpha? Ignore the fourth channel for now + } else { // YCbCr + alpha? Ignore the fourth channel for now z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); } } else for (i = 0; i < z->s->img_x; ++i) { out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 + out[3] = 255; // not used if n==3 out += n; } } else { @@ -4192,7 +4242,7 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * out += n; } } else { - stbi_uc * y = coutput[0]; + stbi_uc *y = coutput[0]; if (n == 1) for (i = 0; i < z->s->img_x; ++i) out[i] = y[i]; @@ -4208,14 +4258,14 @@ static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * *out_x = z->s->img_x; *out_y = z->s->img_y; if (comp) - *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output + *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output return output; } } -static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - unsigned char * result; - stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { + unsigned char *result; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); if (!j) return stbi__errpuc("outofmem", "Out of memory"); memset(j, 0, sizeof(stbi__jpeg)); @@ -4227,9 +4277,9 @@ static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, i return result; } -static int stbi__jpeg_test(stbi__context * s) { +static int stbi__jpeg_test(stbi__context *s) { int r; - stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); if (!j) return stbi__err("outofmem", "Out of memory"); memset(j, 0, sizeof(stbi__jpeg)); @@ -4241,7 +4291,7 @@ static int stbi__jpeg_test(stbi__context * s) { return r; } -static int stbi__jpeg_info_raw(stbi__jpeg * j, int * x, int * y, int * comp) { +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) { if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { stbi__rewind(j->s); return 0; @@ -4255,9 +4305,9 @@ static int stbi__jpeg_info_raw(stbi__jpeg * j, int * x, int * y, int * comp) { return 1; } -static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) { int result; - stbi__jpeg * j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + stbi__jpeg *j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); if (!j) return stbi__err("outofmem", "Out of memory"); memset(j, 0, sizeof(stbi__jpeg)); @@ -4278,9 +4328,9 @@ static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables #define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) -#define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet +#define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) @@ -4308,7 +4358,7 @@ stbi_inline static int stbi__bit_reverse(int v, int bits) { return stbi__bitreverse16(v) >> (16 - bits); } -static int stbi__zbuild_huffman(stbi__zhuffman * z, const stbi_uc * sizelist, int num) { +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) { int i, k = 0; int code, next_code[16], sizes[17]; @@ -4330,11 +4380,11 @@ static int stbi__zbuild_huffman(stbi__zhuffman * z, const stbi_uc * sizelist, in if (sizes[i]) if (code - 1 >= (1 << i)) return stbi__err("bad codelengths", "Corrupt PNG"); - z->maxcode[i] = code << (16 - i); // preshift for inner loop + z->maxcode[i] = code << (16 - i); // preshift for inner loop code <<= 1; k += sizes[i]; } - z->maxcode[16] = 0x10000; // sentinel + z->maxcode[16] = 0x10000; // sentinel for (i = 0; i < num; ++i) { int s = sizelist[i]; if (s) { @@ -4366,19 +4416,23 @@ typedef struct { int num_bits; stbi__uint32 code_buffer; - char * zout; - char * zout_start; - char * zout_end; + char *zout; + char *zout_start; + char *zout_end; int z_expandable; stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf * z) { return (z->zbuffer >= z->zbuffer_end); } +stbi_inline static int stbi__zeof(stbi__zbuf *z) { + return (z->zbuffer >= z->zbuffer_end); +} -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf * z) { return stbi__zeof(z) ? 0 : *z->zbuffer++; } +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) { + return stbi__zeof(z) ? 0 : *z->zbuffer++; +} -static void stbi__fill_bits(stbi__zbuf * z) { +static void stbi__fill_bits(stbi__zbuf *z) { do { if (z->code_buffer >= (1U << z->num_bits)) { z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ @@ -4389,7 +4443,7 @@ static void stbi__fill_bits(stbi__zbuf * z) { } while (z->num_bits <= 24); } -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf * z, int n) { +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) { unsigned int k; if (z->num_bits < n) stbi__fill_bits(z); @@ -4399,7 +4453,7 @@ stbi_inline static unsigned int stbi__zreceive(stbi__zbuf * z, int n) { return k; } -static int stbi__zhuffman_decode_slowpath(stbi__zbuf * a, stbi__zhuffman * z) { +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) { int b, s, k; // not resolved by fast table, so compute it the slow way // use jpeg approach, which requires MSbits at top @@ -4408,19 +4462,19 @@ static int stbi__zhuffman_decode_slowpath(stbi__zbuf * a, stbi__zhuffman * z) { if (k < z->maxcode[s]) break; if (s >= 16) - return -1; // invalid code! + return -1; // invalid code! // code size is s, so: b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; if (b >= STBI__ZNSYMS) - return -1; // some data was corrupt somewhere! + return -1; // some data was corrupt somewhere! if (z->size[b] != s) - return -1; // was originally an assert, but report failure instead. + return -1; // was originally an assert, but report failure instead. a->code_buffer >>= s; a->num_bits -= s; return z->value[b]; } -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf * a, stbi__zhuffman * z) { +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) { int b, s; if (a->num_bits < 16) { if (stbi__zeof(a)) { @@ -4438,9 +4492,9 @@ stbi_inline static int stbi__zhuffman_decode(stbi__zbuf * a, stbi__zhuffman * z) return stbi__zhuffman_decode_slowpath(a, z); } -static int stbi__zexpand(stbi__zbuf * z, char * zout, int n) // need to make room for n bytes +static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes { - char * q; + char *q; unsigned int cur, limit, old_limit; z->zout = zout; if (!z->z_expandable) @@ -4464,26 +4518,26 @@ static int stbi__zexpand(stbi__zbuf * z, char * zout, int n) // need to make roo return 1; } -static const int stbi__zlength_base[31] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, - 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; +static const int stbi__zlength_base[31] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0}; -static const int stbi__zdist_base[32] = {1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, - 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, - 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; +static const int stbi__zdist_base[32] = {1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; -static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; -static int stbi__parse_huffman_block(stbi__zbuf * a) { - char * zout = a->zout; +static int stbi__parse_huffman_block(stbi__zbuf *a) { + char *zout = a->zout; for (;;) { int z = stbi__zhuffman_decode(a, &a->z_length); if (z < 256) { if (z < 0) - return stbi__err("bad huffman code", "Corrupt PNG"); // error in huffman codes + return stbi__err("bad huffman code", "Corrupt PNG"); // error in huffman codes if (zout >= a->zout_end) { if (!stbi__zexpand(a, zout, 1)) return 0; @@ -4491,7 +4545,7 @@ static int stbi__parse_huffman_block(stbi__zbuf * a) { } *zout++ = (char)z; } else { - stbi_uc * p; + stbi_uc *p; int len, dist; if (z == 256) { a->zout = zout; @@ -4499,7 +4553,7 @@ static int stbi__parse_huffman_block(stbi__zbuf * a) { } if (z >= 286) return stbi__err("bad huffman code", - "Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data + "Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data z -= 257; len = stbi__zlength_base[z]; if (stbi__zlength_extra[z]) @@ -4507,7 +4561,7 @@ static int stbi__parse_huffman_block(stbi__zbuf * a) { z = stbi__zhuffman_decode(a, &a->z_distance); if (z < 0 || z >= 30) return stbi__err("bad huffman code", - "Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data + "Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data dist = stbi__zdist_base[z]; if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); @@ -4519,7 +4573,7 @@ static int stbi__parse_huffman_block(stbi__zbuf * a) { zout = a->zout; } p = (stbi_uc *)(zout - dist); - if (dist == 1) { // run of one byte; common in images. + if (dist == 1) { // run of one byte; common in images. stbi_uc v = *p; if (len) { do @@ -4537,10 +4591,10 @@ static int stbi__parse_huffman_block(stbi__zbuf * a) { } } -static int stbi__compute_huffman_codes(stbi__zbuf * a) { +static int stbi__compute_huffman_codes(stbi__zbuf *a) { static const stbi_uc length_dezigzag[19] = {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; stbi__zhuffman z_codelength; - stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op stbi_uc codelength_sizes[19]; int i, n; @@ -4593,15 +4647,15 @@ static int stbi__compute_huffman_codes(stbi__zbuf * a) { return 1; } -static int stbi__parse_uncompressed_block(stbi__zbuf * a) { +static int stbi__parse_uncompressed_block(stbi__zbuf *a) { stbi_uc header[4]; int len, nlen, k; if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard + stbi__zreceive(a, a->num_bits & 7); // discard // drain the bit-packed data into header k = 0; while (a->num_bits > 0) { - header[k++] = (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + header[k++] = (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check a->code_buffer >>= 8; a->num_bits -= 8; } @@ -4625,19 +4679,19 @@ static int stbi__parse_uncompressed_block(stbi__zbuf * a) { return 1; } -static int stbi__parse_zlib_header(stbi__zbuf * a) { +static int stbi__parse_zlib_header(stbi__zbuf *a) { int cmf = stbi__zget8(a); int cm = cmf & 15; /* int cinfo = cmf >> 4; */ int flg = stbi__zget8(a); if (stbi__zeof(a)) - return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec if ((cmf * 256 + flg) % 31 != 0) - return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec if (flg & 32) - return stbi__err("no preset dict", "Corrupt PNG"); // preset dictionary not allowed in png + return stbi__err("no preset dict", "Corrupt PNG"); // preset dictionary not allowed in png if (cm != 8) - return stbi__err("bad compression", "Corrupt PNG"); // DEFLATE required for png + return stbi__err("bad compression", "Corrupt PNG"); // DEFLATE required for png // window = 1 << (8 + cinfo)... but who cares, we fully buffer output return 1; } @@ -4666,7 +4720,7 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf * a, int parse_header) { +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) { int final, type; if (parse_header) if (!stbi__parse_zlib_header(a)) @@ -4699,7 +4753,7 @@ static int stbi__parse_zlib(stbi__zbuf * a, int parse_header) { return 1; } -static int stbi__do_zlib(stbi__zbuf * a, char * obuf, int olen, int exp, int parse_header) { +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) { a->zout_start = obuf; a->zout = obuf; a->zout_end = obuf + olen; @@ -4708,9 +4762,9 @@ static int stbi__do_zlib(stbi__zbuf * a, char * obuf, int olen, int exp, int par return stbi__parse_zlib(a, parse_header); } -STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen) { +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) { stbi__zbuf a; - char * p = (char *)stbi__malloc(initial_size); + char *p = (char *)stbi__malloc(initial_size); if (p == NULL) return NULL; a.zbuffer = (stbi_uc *)buffer; @@ -4725,14 +4779,14 @@ STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, i } } -STBIDEF char * stbi_zlib_decode_malloc(char const * buffer, int len, int * outlen) { +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) { return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); } -STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, - int parse_header) { +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, + int parse_header) { stbi__zbuf a; - char * p = (char *)stbi__malloc(initial_size); + char *p = (char *)stbi__malloc(initial_size); if (p == NULL) return NULL; a.zbuffer = (stbi_uc *)buffer; @@ -4747,7 +4801,7 @@ STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, } } -STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, char const * ibuffer, int ilen) { +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) { stbi__zbuf a; a.zbuffer = (stbi_uc *)ibuffer; a.zbuffer_end = (stbi_uc *)ibuffer + ilen; @@ -4757,9 +4811,9 @@ STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, char const * ibuff return -1; } -STBIDEF char * stbi_zlib_decode_noheader_malloc(char const * buffer, int len, int * outlen) { +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) { stbi__zbuf a; - char * p = (char *)stbi__malloc(16384); + char *p = (char *)stbi__malloc(16384); if (p == NULL) return NULL; a.zbuffer = (stbi_uc *)buffer; @@ -4774,7 +4828,7 @@ STBIDEF char * stbi_zlib_decode_noheader_malloc(char const * buffer, int len, in } } -STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen) { +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) { stbi__zbuf a; a.zbuffer = (stbi_uc *)ibuffer; a.zbuffer_end = (stbi_uc *)ibuffer + ilen; @@ -4801,14 +4855,14 @@ typedef struct { stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context * s) { +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) { stbi__pngchunk c; c.length = stbi__get32be(s); c.type = stbi__get32be(s); return c; } -static int stbi__check_png_header(stbi__context * s) { +static int stbi__check_png_header(stbi__context *s) { static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; int i; for (i = 0; i < 8; ++i) @@ -4818,7 +4872,7 @@ static int stbi__check_png_header(stbi__context * s) { } typedef struct { - stbi__context * s; + stbi__context *s; stbi_uc *idata, *expanded, *out; int depth; } stbi__png; @@ -4851,21 +4905,21 @@ static int stbi__paeth(int a, int b, int c) { static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, 0, 0, 0, 0x01}; // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) { int bytes = (depth == 16 ? 2 : 1); - stbi__context * s = a->s; + stbi__context *s = a->s; stbi__uint32 i, j, stride = x * out_n * bytes; stbi__uint32 img_len, img_width_bytes; int k; - int img_n = s->img_n; // copy it into a local for later + int img_n = s->img_n; // copy it into a local for later int output_bytes = out_n * bytes; int filter_bytes = img_n * bytes; int width = x; STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); - a->out = (stbi_uc *)stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into + a->out = (stbi_uc *)stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into if (!a->out) return stbi__err("outofmem", "Out of memory"); @@ -4881,8 +4935,8 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 return stbi__err("not enough pixels", "Corrupt PNG"); for (j = 0; j < y; ++j) { - stbi_uc * cur = a->out + stride * j; - stbi_uc * prior; + stbi_uc *cur = a->out + stride * j; + stbi_uc *prior; int filter = *raw++; if (filter > 4) @@ -4891,11 +4945,11 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 if (depth < 8) { if (img_width_bytes > x) return stbi__err("invalid width", "Corrupt PNG"); - cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place + cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place filter_bytes = 1; width = img_width_bytes; } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above + prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above // if first row, use special filter that doesn't sample previous row if (j == 0) @@ -4930,14 +4984,14 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 if (depth == 8) { if (img_n != out_n) - cur[img_n] = 255; // first pixel + cur[img_n] = 255; // first pixel raw += img_n; cur += out_n; prior += out_n; } else if (depth == 16) { if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes + 1] = 255; // first pixel bottom byte + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte } raw += filter_bytes; cur += output_bytes; @@ -4951,27 +5005,37 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 // this is a little gross, so that we don't switch per-pixel or per-component if (depth < 8 || img_n == out_n) { int nk = (width - 1) * filter_bytes; -#define STBI__CASE(f) \ - case f: \ +#define STBI__CASE(f) \ + case f: \ for (k = 0; k < nk; ++k) switch (filter) { // "none" filter turns into a memcpy here; make that explicit. case STBI__F_none: memcpy(cur, raw, nk); break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); } + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); + } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); + } break; STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); } + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); + } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); } + STBI__CASE(STBI__F_paeth_first) { + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); + } break; } #undef STBI__CASE @@ -4983,21 +5047,33 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, cur += output_bytes, prior += output_bytes) \ for (k = 0; k < filter_bytes; ++k) switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } + STBI__CASE(STBI__F_none) { + cur[k] = raw[k]; + } break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); } + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); + } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); + } break; STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); } + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); + } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); } + STBI__CASE(STBI__F_paeth_first) { + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); + } break; } #undef STBI__CASE @@ -5005,7 +5081,7 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 // the loop above sets the high byte of the pixels' alpha, but for // 16 bit png files we also need the low byte set. we'll do that here. if (depth == 16) { - cur = a->out + stride * j; // start at the beginning of the row again + cur = a->out + stride * j; // start at the beginning of the row again for (i = 0; i < x; ++i, cur += output_bytes) { cur[filter_bytes + 1] = 255; } @@ -5018,12 +5094,12 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 // intefere with filtering but will still be in the cache. if (depth < 8) { for (j = 0; j < y; ++j) { - stbi_uc * cur = a->out + stride * j; - stbi_uc * in = a->out + stride * j + x * out_n - img_width_bytes; + stbi_uc *cur = a->out + stride * j; + stbi_uc *in = a->out + stride * j + x * out_n - img_width_bytes; // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for // 1/2/4-bit png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that // will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range + stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range // note that the final byte might overshoot and write more data than desired. // we can allocate enough data that this never writes out of memory, but it @@ -5102,8 +5178,8 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 // this is done in a separate pass due to the decoding relying // on the data being untouched, but could probably be done // per-line during decode if care is taken. - stbi_uc * cur = a->out; - stbi__uint16 * cur16 = (stbi__uint16 *)cur; + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16 *)cur; for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { *cur16 = (cur[0] << 8) | cur[1]; @@ -5113,11 +5189,11 @@ static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 return 1; } -static int stbi__create_png_image(stbi__png * a, stbi_uc * image_data, stbi__uint32 image_data_len, int out_n, int depth, +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) { int bytes = (depth == 16 ? 2 : 1); int out_bytes = out_n * bytes; - stbi_uc * final; + stbi_uc *final; int p; if (!interlaced) return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); @@ -5159,10 +5235,10 @@ static int stbi__create_png_image(stbi__png * a, stbi_uc * image_data, stbi__uin return 1; } -static int stbi__compute_transparency(stbi__png * z, stbi_uc tc[3], int out_n) { - stbi__context * s = z->s; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) { + stbi__context *s = z->s; stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc * p = z->out; + stbi_uc *p = z->out; // compute color-based transparency, assuming we've // already got 255 as the alpha value in the output @@ -5183,10 +5259,10 @@ static int stbi__compute_transparency(stbi__png * z, stbi_uc tc[3], int out_n) { return 1; } -static int stbi__compute_transparency16(stbi__png * z, stbi__uint16 tc[3], int out_n) { - stbi__context * s = z->s; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) { + stbi__context *s = z->s; stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 * p = (stbi__uint16 *)z->out; + stbi__uint16 *p = (stbi__uint16 *)z->out; // compute color-based transparency, assuming we've // already got 65535 as the alpha value in the output @@ -5207,7 +5283,7 @@ static int stbi__compute_transparency16(stbi__png * z, stbi__uint16 tc[3], int o return 1; } -static int stbi__expand_png_palette(stbi__png * a, stbi_uc * palette, int len, int pal_img_n) { +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) { stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; stbi_uc *p, *temp_out, *orig = a->out; @@ -5272,17 +5348,17 @@ STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_conve stbi__de_iphone_flag_set = 1; } -#define stbi__unpremultiply_on_load \ +#define stbi__unpremultiply_on_load \ (stbi__unpremultiply_on_load_set ? stbi__unpremultiply_on_load_local : stbi__unpremultiply_on_load_global) #define stbi__de_iphone_flag (stbi__de_iphone_flag_set ? stbi__de_iphone_flag_local : stbi__de_iphone_flag_global) -#endif // STBI_THREAD_LOCAL +#endif // STBI_THREAD_LOCAL -static void stbi__de_iphone(stbi__png * z) { - stbi__context * s = z->s; +static void stbi__de_iphone(stbi__png *z) { + stbi__context *s = z->s; stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc * p = z->out; + stbi_uc *p = z->out; - if (s->img_out_n == 3) { // convert bgr to rgb + if (s->img_out_n == 3) { // convert bgr to rgb for (i = 0; i < pixel_count; ++i) { stbi_uc t = p[0]; p[0] = p[2]; @@ -5321,13 +5397,13 @@ static void stbi__de_iphone(stbi__png * z) { #define STBI__PNG_TYPE(a, b, c, d) (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + (unsigned)(d)) -static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) { stbi_uc palette[1024], pal_img_n = 0; stbi_uc has_trans = 0, tc[3] = {0}; stbi__uint16 tc16[3]; stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; int first = 1, k, interlace = 0, color = 0, is_iphone = 0; - stbi__context * s = z->s; + stbi__context *s = z->s; z->expanded = NULL; z->idata = NULL; @@ -5444,11 +5520,11 @@ static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { } if (z->depth == 16) { for (k = 0; k < s->img_n; ++k) - tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is } else { for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * - stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger + stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger } } break; @@ -5471,7 +5547,7 @@ static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { return 0; if (ioff + c.length > idata_limit) { stbi__uint32 idata_limit_old = idata_limit; - stbi_uc * p; + stbi_uc *p; if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; while (ioff + c.length > idata_limit) @@ -5497,12 +5573,12 @@ static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { if (z->idata == NULL) return stbi__err("no IDAT", "Corrupt PNG"); // initial guess for decoded data size to avoid unnecessary reallocs - bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag((char *)z->idata, ioff, raw_len, (int *)&raw_len, !is_iphone); if (z->expanded == NULL) - return 0; // zlib should set error + return 0; // zlib should set error STBI_FREE(z->idata); z->idata = NULL; if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || has_trans) @@ -5524,7 +5600,7 @@ static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { stbi__de_iphone(z); if (pal_img_n) { // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had + s->img_n = pal_img_n; // record the actual colors we had s->img_out_n = pal_img_n; if (req_comp >= 3) s->img_out_n = req_comp; @@ -5564,8 +5640,8 @@ static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { } } -static void * stbi__do_png(stbi__png * p, int * x, int * y, int * n, int req_comp, stbi__result_info * ri) { - void * result = NULL; +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) { + void *result = NULL; if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { @@ -5601,20 +5677,20 @@ static void * stbi__do_png(stbi__png * p, int * x, int * y, int * n, int req_com return result; } -static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { stbi__png p; p.s = s; return stbi__do_png(&p, x, y, comp, req_comp, ri); } -static int stbi__png_test(stbi__context * s) { +static int stbi__png_test(stbi__context *s) { int r; r = stbi__check_png_header(s); stbi__rewind(s); return r; } -static int stbi__png_info_raw(stbi__png * p, int * x, int * y, int * comp) { +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) { if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { stbi__rewind(p->s); return 0; @@ -5628,13 +5704,13 @@ static int stbi__png_info_raw(stbi__png * p, int * x, int * y, int * comp) { return 1; } -static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) { stbi__png p; p.s = s; return stbi__png_info_raw(&p, x, y, comp); } -static int stbi__png_is16(stbi__context * s) { +static int stbi__png_is16(stbi__context *s) { stbi__png p; p.s = s; if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) @@ -5650,23 +5726,23 @@ static int stbi__png_is16(stbi__context * s) { // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context * s) { +static int stbi__bmp_test_raw(stbi__context *s) { int r; int sz; if (stbi__get8(s) != 'B') return 0; if (stbi__get8(s) != 'M') return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset sz = stbi__get32le(s); r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); return r; } -static int stbi__bmp_test(stbi__context * s) { +static int stbi__bmp_test(stbi__context *s) { int r = stbi__bmp_test_raw(s); stbi__rewind(s); return r; @@ -5700,11 +5776,11 @@ static int stbi__high_bit(unsigned int z) { } static int stbi__bitcount(unsigned int a) { - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits return a & 0xff; } @@ -5724,7 +5800,15 @@ static int stbi__shiftsigned(unsigned int v, int shift, int bits) { 0x01 /*0b00000001*/, }; static unsigned int shift_table[9] = { - 0, 0, 0, 1, 0, 2, 4, 6, 0, + 0, + 0, + 0, + 1, + 0, + 2, + 4, + 6, + 0, }; if (shift < 0) v <<= -shift; @@ -5742,7 +5826,7 @@ typedef struct { int extra_read; } stbi__bmp_data; -static int stbi__bmp_set_mask_defaults(stbi__bmp_data * info, int compress) { +static int stbi__bmp_set_mask_defaults(stbi__bmp_data *info, int compress) { // BI_BITFIELDS specifies masks explicitly, don't override if (compress == 3) return 1; @@ -5757,23 +5841,23 @@ static int stbi__bmp_set_mask_defaults(stbi__bmp_data * info, int compress) { info->mg = 0xffu << 8; info->mb = 0xffu << 0; info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 } else { // otherwise, use defaults, which is all-0 info->mr = info->mg = info->mb = info->ma = 0; } return 1; } - return 0; // error + return 0; // error } -static void * stbi__bmp_parse_header(stbi__context * s, stbi__bmp_data * info) { +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) { int hsz; if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved info->offset = stbi__get32le(s); info->hsz = hsz = stbi__get32le(s); info->mr = info->mg = info->mb = info->ma = 0; @@ -5800,14 +5884,14 @@ static void * stbi__bmp_parse_header(stbi__context * s, stbi__bmp_data * info) { return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); if (compress >= 4) return stbi__errpuc("BMP JPEG/PNG", - "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes + "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes if (compress == 3 && info->bpp != 16 && info->bpp != 32) - return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important + return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important if (hsz == 40 || hsz == 56) { if (hsz == 56) { stbi__get32le(s); @@ -5840,24 +5924,24 @@ static void * stbi__bmp_parse_header(stbi__context * s, stbi__bmp_data * info) { info->mg = stbi__get32le(s); info->mb = stbi__get32le(s); info->ma = stbi__get32le(s); - if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs + if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs stbi__bmp_set_mask_defaults(info, compress); - stbi__get32le(s); // discard color space + stbi__get32le(s); // discard color space for (i = 0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters + stbi__get32le(s); // discard color space parameters if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved } } } return (void *)1; } -static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * out; +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { + stbi_uc *out; unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; stbi_uc pal[256][4]; int psize = 0, i, j, width; @@ -5867,7 +5951,7 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in info.all_a = 255; if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set + return NULL; // error code already set flip_vertically = ((int)s->img_y) > 0; s->img_y = abs((int)s->img_y); @@ -5894,8 +5978,8 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in // accept some number of extra bytes after the header, but if the offset points either to before // the header ends or implies a large amount of extra data, reject the file as malformed int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); - int header_limit = 1024; // max we actually read is below 256 bytes currently. - int extra_data_limit = 256 * 4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. + int header_limit = 1024; // max we actually read is below 256 bytes currently. + int extra_data_limit = 256 * 4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { return stbi__errpuc("bad header", "Corrupt BMP"); } @@ -5914,10 +5998,10 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in s->img_n = 3; else s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 target = req_comp; else - target = s->img_n; // if they want monochrome, we'll post-convert + target = s->img_n; // if they want monochrome, we'll post-convert // sanity-check size if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) @@ -6072,8 +6156,8 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in if (flip_vertically) { stbi_uc t; for (j = 0; j < (int)s->img_y >> 1; ++j) { - stbi_uc * p1 = out + j * s->img_x * target; - stbi_uc * p2 = out + (s->img_y - 1 - j) * s->img_x * target; + stbi_uc *p1 = out + j * s->img_x * target; + stbi_uc *p2 = out + (s->img_y - 1 - j) * s->img_x * target; for (i = 0; i < (int)s->img_x * target; ++i) { t = p1[i]; p1[i] = p2[i]; @@ -6085,7 +6169,7 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in if (req_comp && req_comp != target) { out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); if (out == NULL) - return out; // stbi__convert_format frees input on failure + return out; // stbi__convert_format frees input on failure } *x = s->img_x; @@ -6100,7 +6184,7 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int * is_rgb16) { +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int *is_rgb16) { // only RGB or RGBA (incl. 16bit) or grey allowed if (is_rgb16) *is_rgb16 = 0; @@ -6115,7 +6199,7 @@ static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int * is_rgb16) { if (is_rgb16) *is_rgb16 = 1; return STBI_rgb; - case 24: // fallthrough + case 24: // fallthrough case 32: return bits_per_pixel / 8; default: @@ -6123,49 +6207,49 @@ static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int * is_rgb16) { } } -static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) { int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; int sz, tga_colormap_type; - stbi__get8(s); // discard Offset - tga_colormap_type = stbi__get8(s); // colormap type + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type if (tga_colormap_type > 1) { stbi__rewind(s); - return 0; // only RGB or indexed allowed + return 0; // only RGB or indexed allowed } - tga_image_type = stbi__get8(s); // image type - if (tga_colormap_type == 1) { // colormapped (paletted) image + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image if (tga_image_type != 1 && tga_image_type != 9) { stbi__rewind(s); return 0; } - stbi__skip(s, 4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry + stbi__skip(s, 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { stbi__rewind(s); return 0; } - stbi__skip(s, 4); // skip image x and y origin + stbi__skip(s, 4); // skip image x and y origin tga_colormap_bpp = sz; - } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE if ((tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11)) { stbi__rewind(s); - return 0; // only RGB or grey allowed, +/- RLE + return 0; // only RGB or grey allowed, +/- RLE } - stbi__skip(s, 9); // skip colormap specification and image x/y origin + stbi__skip(s, 9); // skip colormap specification and image x/y origin tga_colormap_bpp = 0; } tga_w = stbi__get16le(s); if (tga_w < 1) { stbi__rewind(s); - return 0; // test width + return 0; // test width } tga_h = stbi__get16le(s); if (tga_h < 1) { stbi__rewind(s); - return 0; // test height + return 0; // test height } - tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits if (tga_colormap_bpp != 0) { if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { // when using a colormap, tga_bits_per_pixel is the size of the indexes @@ -6187,41 +6271,41 @@ static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { *y = tga_h; if (comp) *comp = tga_comp; - return 1; // seems to have passed everything + return 1; // seems to have passed everything } -static int stbi__tga_test(stbi__context * s) { +static int stbi__tga_test(stbi__context *s) { int res = 0; int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type if (tga_color_type > 1) - goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if (tga_color_type == 1) { // colormapped (paletted) image + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image if (sz != 1 && sz != 9) - goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s, 4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) goto errorEnd; - stbi__skip(s, 4); // skip image x and y origin - } else { // "normal" image w/o colormap + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) - goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s, 9); // skip colormap specification and image x/y origin + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin } if (stbi__get16le(s) < 1) - goto errorEnd; // test width + goto errorEnd; // test width if (stbi__get16le(s) < 1) - goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) - goto errorEnd; // for colormapped images, bpp is size of an index + goto errorEnd; // for colormapped images, bpp is size of an index if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) goto errorEnd; - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + res = 1; // if we got this far, everything's good and we can return 1 instead of 0 errorEnd: stbi__rewind(s); @@ -6229,7 +6313,7 @@ static int stbi__tga_test(stbi__context * s) { } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context * s, stbi_uc * out) { +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc *out) { stbi__uint16 px = (stbi__uint16)stbi__get16le(s); stbi__uint16 fiveBitMask = 31; // we have 3 channels with 5bits each @@ -6247,7 +6331,7 @@ static void stbi__tga_read_rgb16(stbi__context * s, stbi_uc * out) { // so let's treat all 15 and 16bit TGAs as RGB with no alpha. } -static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { // read in the TGA header stuff int tga_offset = stbi__get8(s); int tga_indexed = stbi__get8(s); @@ -6265,16 +6349,16 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in int tga_inverted = stbi__get8(s); // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) // image data - unsigned char * tga_data; - unsigned char * tga_palette = NULL; + unsigned char *tga_data; + unsigned char *tga_palette = NULL; int i, j; unsigned char raw_data[4] = {0}; int RLE_count = 0; int RLE_repeating = 0; int read_next_pixel = 1; STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large", "Very large image (corrupt?)"); @@ -6294,7 +6378,7 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); // tga info @@ -6316,7 +6400,7 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { for (i = 0; i < tga_height; ++i) { int row = tga_inverted ? tga_height - i - 1 : i; - stbi_uc * tga_row = tga_data + row * tga_width * tga_comp; + stbi_uc *tga_row = tga_data + row * tga_width * tga_comp; stbi__getn(s, tga_row, tga_width * tga_comp); } } else { @@ -6336,7 +6420,7 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in return stbi__errpuc("outofmem", "Out of memory"); } if (tga_rgb16) { - stbi_uc * pal_entry = tga_palette; + stbi_uc *pal_entry = tga_palette; STBI_ASSERT(tga_comp == STBI_rgb); for (i = 0; i < tga_palette_len; ++i) { stbi__tga_read_rgb16(s, pal_entry); @@ -6389,7 +6473,7 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in } // clear the reading flag for the next pixel read_next_pixel = 0; - } // end of reading a pixel + } // end of reading a pixel // copy data for (j = 0; j < tga_comp; ++j) @@ -6420,7 +6504,7 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in // swap RGB - if the source data was RGB16, it already is in the right order if (tga_comp >= 3 && !tga_rgb16) { - unsigned char * tga_pixel = tga_data; + unsigned char *tga_pixel = tga_data; for (i = 0; i < tga_width * tga_height; ++i) { unsigned char temp = tga_pixel[0]; tga_pixel[0] = tga_pixel[2]; @@ -6446,13 +6530,13 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in // Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context * s) { +static int stbi__psd_test(stbi__context *s) { int r = (stbi__get32be(s) == 0x38425053); stbi__rewind(s); return r; } -static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) { +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) { int count, nleft, len; count = 0; @@ -6464,7 +6548,7 @@ static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) // Copy next len+1 bytes literally. len++; if (len > nleft) - return 0; // corrupt data + return 0; // corrupt data count += len; while (len) { *p = stbi__get8(s); @@ -6477,7 +6561,7 @@ static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) // (Interpret len as a negative 8-bit int.) len = 257 - len; if (len > nleft) - return 0; // corrupt data + return 0; // corrupt data val = stbi__get8(s); count += len; while (len) { @@ -6491,17 +6575,17 @@ static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) return 1; } -static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) { int pixelCount; int channelCount, compression; int channel, i; int bitdepth; int w, h; - stbi_uc * out; + stbi_uc *out; STBI_NOTUSED(ri); // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" + if (stbi__get32be(s) != 0x38425053) // "8BPS" return stbi__errpuc("not PSD", "Corrupt PSD image"); // Check file type version. @@ -6595,7 +6679,7 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in // Read the RLE data by channel. for (channel = 0; channel < 4; channel++) { - stbi_uc * p; + stbi_uc *p; p = out + channel; if (channel >= channelCount) { @@ -6619,24 +6703,24 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in if (channel >= channelCount) { // Fill this channel with default data. if (bitdepth == 16 && bpc == 16) { - stbi__uint16 * q = ((stbi__uint16 *)out) + channel; + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; stbi__uint16 val = channel == 3 ? 65535 : 0; for (i = 0; i < pixelCount; i++, q += 4) *q = val; } else { - stbi_uc * p = out + channel; + stbi_uc *p = out + channel; stbi_uc val = channel == 3 ? 255 : 0; for (i = 0; i < pixelCount; i++, p += 4) *p = val; } } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 * q = ((stbi__uint16 *)out) + channel; + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; for (i = 0; i < pixelCount; i++, q += 4) *q = (stbi__uint16)stbi__get16be(s); } else { - stbi_uc * p = out + channel; - if (bitdepth == 16) { // input bpc + stbi_uc *p = out + channel; + if (bitdepth == 16) { // input bpc for (i = 0; i < pixelCount; i++, p += 4) *p = (stbi_uc)(stbi__get16be(s) >> 8); } else { @@ -6652,7 +6736,7 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in if (channelCount >= 4) { if (ri->bits_per_channel == 16) { for (i = 0; i < w * h; ++i) { - stbi__uint16 * pixel = (stbi__uint16 *)out + 4 * i; + stbi__uint16 *pixel = (stbi__uint16 *)out + 4 * i; if (pixel[3] != 0 && pixel[3] != 65535) { float a = pixel[3] / 65535.0f; float ra = 1.0f / a; @@ -6664,7 +6748,7 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in } } else { for (i = 0; i < w * h; ++i) { - unsigned char * pixel = out + 4 * i; + unsigned char *pixel = out + 4 * i; if (pixel[3] != 0 && pixel[3] != 255) { float a = pixel[3] / 255.0f; float ra = 1.0f / a; @@ -6684,7 +6768,7 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in else out = stbi__convert_format(out, 4, req_comp, w, h); if (out == NULL) - return out; // stbi__convert_format frees input on failure + return out; // stbi__convert_format frees input on failure } if (comp) @@ -6704,7 +6788,7 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context * s, const char * str) { +static int stbi__pic_is4(stbi__context *s, const char *str) { int i; for (i = 0; i < 4; ++i) if (stbi__get8(s) != (stbi_uc)str[i]) @@ -6713,7 +6797,7 @@ static int stbi__pic_is4(stbi__context * s, const char * str) { return 1; } -static int stbi__pic_test_core(stbi__context * s) { +static int stbi__pic_test_core(stbi__context *s) { int i; if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) @@ -6732,7 +6816,7 @@ typedef struct { stbi_uc size, type, channel; } stbi__pic_packet; -static stbi_uc * stbi__readval(stbi__context * s, int channel, stbi_uc * dest) { +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) { int mask = 0x80, i; for (i = 0; i < 4; ++i, mask >>= 1) { @@ -6746,7 +6830,7 @@ static stbi_uc * stbi__readval(stbi__context * s, int channel, stbi_uc * dest) { return dest; } -static void stbi__copyval(int channel, stbi_uc * dest, const stbi_uc * src) { +static void stbi__copyval(int channel, stbi_uc *dest, const stbi_uc *src) { int mask = 0x80, i; for (i = 0; i < 4; ++i, mask >>= 1) @@ -6754,14 +6838,14 @@ static void stbi__copyval(int channel, stbi_uc * dest, const stbi_uc * src) { dest[i] = src[i]; } -static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, int * comp, stbi_uc * result) { +static stbi_uc *stbi__pic_load_core(stbi__context *s, int width, int height, int *comp, stbi_uc *result) { int act_comp = 0, num_packets = 0, y, chained; stbi__pic_packet packets[10]; // this will (should...) cater for even some bizarre stuff like having data // for the same channel in multiple packets. do { - stbi__pic_packet * packet; + stbi__pic_packet *packet; if (num_packets == sizeof(packets) / sizeof(packets[0])) return stbi__errpuc("bad format", "too many packets"); @@ -6781,20 +6865,20 @@ static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, i return stbi__errpuc("bad format", "packet isn't 8bpp"); } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? for (y = 0; y < height; ++y) { int packet_idx; for (packet_idx = 0; packet_idx < num_packets; ++packet_idx) { - stbi__pic_packet * packet = &packets[packet_idx]; - stbi_uc * dest = result + y * width * 4; + stbi__pic_packet *packet = &packets[packet_idx]; + stbi_uc *dest = result + y * width * 4; switch (packet->type) { default: return stbi__errpuc("bad format", "packet has bad compression type"); - case 0: { // uncompressed + case 0: { // uncompressed int x; for (x = 0; x < width; ++x, dest += 4) @@ -6803,7 +6887,7 @@ static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, i break; } - case 1: // Pure RLE + case 1: // Pure RLE { int left = width, i; @@ -6826,14 +6910,14 @@ static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, i } } break; - case 2: { // Mixed RLE + case 2: { // Mixed RLE int left = width; while (left > 0) { int count = stbi__get8(s), i; if (stbi__at_eof(s)) return stbi__errpuc("bad file", "file too short (mixed read count)"); - if (count >= 128) { // Repeated + if (count >= 128) { // Repeated stbi_uc value[4]; if (count == 128) @@ -6848,7 +6932,7 @@ static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, i for (i = 0; i < count; ++i, dest += 4) stbi__copyval(packet->channel, dest, value); - } else { // Raw + } else { // Raw ++count; if (count > left) return stbi__errpuc("bad file", "scanline overrun"); @@ -6868,8 +6952,8 @@ static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, i return result; } -static void * stbi__pic_load(stbi__context * s, int * px, int * py, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * result; +static void *stbi__pic_load(stbi__context *s, int *px, int *py, int *comp, int req_comp, stbi__result_info *ri) { + stbi_uc *result; int i, x, y, internal_comp; STBI_NOTUSED(ri); @@ -6892,9 +6976,9 @@ static void * stbi__pic_load(stbi__context * s, int * px, int * py, int * comp, if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); // skip `ratio' - stbi__get16be(s); // skip `fields' - stbi__get16be(s); // skip `pad' + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' // intermediate buffer is RGBA result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); @@ -6915,7 +6999,7 @@ static void * stbi__pic_load(stbi__context * s, int * px, int * py, int * comp, return result; } -static int stbi__pic_test(stbi__context * s) { +static int stbi__pic_test(stbi__context *s) { int r = stbi__pic_test_core(s); stbi__rewind(s); return r; @@ -6934,14 +7018,14 @@ typedef struct { typedef struct { int w, h; - stbi_uc * out; // output buffer (always 4 components) - stbi_uc * background; // The current "background" as far as a gif is concerned - stbi_uc * history; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; int flags, bgindex, ratio, transparent, eflags; stbi_uc pal[256][4]; stbi_uc lpal[256][4]; stbi__gif_lzw codes[8192]; - stbi_uc * color_table; + stbi_uc *color_table; int parse, step; int lflags; int start_x, start_y; @@ -6951,7 +7035,7 @@ typedef struct { int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context * s) { +static int stbi__gif_test_raw(stbi__context *s) { int sz; if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; @@ -6963,13 +7047,13 @@ static int stbi__gif_test_raw(stbi__context * s) { return 1; } -static int stbi__gif_test(stbi__context * s) { +static int stbi__gif_test(stbi__context *s) { int r = stbi__gif_test_raw(s); stbi__rewind(s); return r; } -static void stbi__gif_parse_colortable(stbi__context * s, stbi_uc pal[256][4], int num_entries, int transp) { +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) { int i; for (i = 0; i < num_entries; ++i) { pal[i][2] = stbi__get8(s); @@ -6979,7 +7063,7 @@ static void stbi__gif_parse_colortable(stbi__context * s, stbi_uc pal[256][4], i } } -static int stbi__gif_header(stbi__context * s, stbi__gif * g, int * comp, int is_info) { +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) { stbi_uc version; if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return stbi__err("not GIF", "Corrupt GIF"); @@ -7004,7 +7088,7 @@ static int stbi__gif_header(stbi__context * s, stbi__gif * g, int * comp, int is return stbi__err("too large", "Very large image (corrupt?)"); if (comp != 0) - *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments if (is_info) return 1; @@ -7015,8 +7099,8 @@ static int stbi__gif_header(stbi__context * s, stbi__gif * g, int * comp, int is return 1; } -static int stbi__gif_info_raw(stbi__context * s, int * x, int * y, int * comp) { - stbi__gif * g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) { + stbi__gif *g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); if (!g) return stbi__err("outofmem", "Out of memory"); if (!stbi__gif_header(s, g, comp, 1)) { @@ -7032,7 +7116,7 @@ static int stbi__gif_info_raw(stbi__context * s, int * x, int * y, int * comp) { return 1; } -static void stbi__out_gif_code(stbi__gif * g, stbi__uint16 code) { +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) { stbi_uc *p, *c; int idx; @@ -7049,7 +7133,7 @@ static void stbi__out_gif_code(stbi__gif * g, stbi__uint16 code) { g->history[idx / 4] = 1; c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; + if (c[3] > 128) { // don't render transparent pixels; p[0] = c[2]; p[1] = c[1]; p[2] = c[0]; @@ -7069,12 +7153,12 @@ static void stbi__out_gif_code(stbi__gif * g, stbi__uint16 code) { } } -static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) { stbi_uc lzw_cs; stbi__int32 len, init_code; stbi__uint32 first; stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw * p; + stbi__gif_lzw *p; lzw_cs = stbi__get8(s); if (lzw_cs > 12) @@ -7099,7 +7183,7 @@ static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { for (;;) { if (valid_bits < codesize) { if (len == 0) { - len = stbi__get8(s); // start new block + len = stbi__get8(s); // start new block if (len == 0) return g->out; } @@ -7111,13 +7195,13 @@ static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { bits >>= codesize; valid_bits -= codesize; // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code + if (code == clear) { // clear code codesize = lzw_cs + 1; codemask = (1 << codesize) - 1; avail = clear + 2; oldcode = -1; first = 0; - } else if (code == clear + 1) { // end of stream code + } else if (code == clear + 1) { // end of stream code stbi__skip(s, len); while ((len = stbi__get8(s)) > 0) stbi__skip(s, len); @@ -7156,7 +7240,7 @@ static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { // this function is designed to support animated gifs, although stb_image doesn't support it // two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * comp, int req_comp, stbi_uc * two_back) { +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) { int dispose; int first_frame; int pi; @@ -7167,7 +7251,7 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com first_frame = 0; if (g->out == 0) { if (!stbi__gif_header(s, g, comp, 0)) - return 0; // stbi__g_failure_reason set by stbi__gif_header + return 0; // stbi__g_failure_reason set by stbi__gif_header if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) return stbi__errpuc("too large", "GIF image is too large"); pcount = g->w * g->h; @@ -7181,8 +7265,8 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com // background colour is only used for pixels that are not rendered first frame, after that "background" // color refers to the color that was there the previous frame. memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame + memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, pcount); // pixels that were affected previous frame first_frame = 1; } else { // second frame - how do we dispose of the previous one? @@ -7190,10 +7274,10 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com pcount = g->w * g->h; if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background + dispose = 2; // if I don't have an image to revert back to, default to the old background } - if (dispose == 3) { // use previous graphic + if (dispose == 3) { // use previous graphic for (pi = 0; pi < pcount; ++pi) { if (g->history[pi]) { memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); @@ -7218,7 +7302,7 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com } // clear my history; - memset(g->history, 0x00, g->w * g->h); // pixels that were affected previous frame + memset(g->history, 0x00, g->w * g->h); // pixels that were affected previous frame for (;;) { int tag = stbi__get8(s); @@ -7226,7 +7310,7 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com case 0x2C: /* Image Descriptor */ { stbi__int32 x, y, w, h; - stbi_uc * o; + stbi_uc *o; x = stbi__get16le(s); y = stbi__get16le(s); @@ -7253,7 +7337,7 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com g->lflags = stbi__get8(s); if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing + g->step = 8 * g->line_size; // first interlaced spacing g->parse = 3; } else { g->step = g->line_size; @@ -7279,7 +7363,7 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com for (pi = 0; pi < pcount; ++pi) { if (g->history[pi] == 0) { g->pal[g->bgindex][3] = - 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; + 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); } } @@ -7288,15 +7372,15 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com return o; } - case 0x21: // Comment Extension. + case 0x21: // Comment Extension. { int len; int ext = stbi__get8(s); - if (ext == 0xF9) { // Graphic Control Extension. + if (ext == 0xF9) { // Graphic Control Extension. len = stbi__get8(s); if (len == 4) { g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. + g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. // unset old transparent if (g->transparent >= 0) { @@ -7323,8 +7407,8 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com break; } - case 0x3B: // gif stream termination code - return (stbi_uc *)s; // using '1' causes warning on some compilers + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers default: return stbi__errpuc("unknown code", "Corrupt GIF"); @@ -7332,7 +7416,7 @@ static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * com } } -static void * stbi__load_gif_main_outofmem(stbi__gif * g, stbi_uc * out, int ** delays) { +static void *stbi__load_gif_main_outofmem(stbi__gif *g, stbi_uc *out, int **delays) { STBI_FREE(g->out); STBI_FREE(g->history); STBI_FREE(g->background); @@ -7344,12 +7428,12 @@ static void * stbi__load_gif_main_outofmem(stbi__gif * g, stbi_uc * out, int ** return stbi__errpuc("outofmem", "Out of memory"); } -static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp) { +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) { if (stbi__gif_test(s)) { int layers = 0; - stbi_uc * u = 0; - stbi_uc * out = 0; - stbi_uc * two_back = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; stbi__gif g; int stride; int out_size = 0; @@ -7366,7 +7450,7 @@ static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int do { u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); if (u == (stbi_uc *)s) - u = 0; // end of animated gif marker + u = 0; // end of animated gif marker if (u) { *x = g.w; @@ -7375,7 +7459,7 @@ static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int stride = g.w * g.h * 4; if (out) { - void * tmp = (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + void *tmp = (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); if (!tmp) return stbi__load_gif_main_outofmem(&g, out, delays); else { @@ -7384,7 +7468,7 @@ static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int } if (delays) { - int * new_delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, sizeof(int) * layers); + int *new_delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, sizeof(int) * layers); if (!new_delays) return stbi__load_gif_main_outofmem(&g, out, delays); *delays = new_delays; @@ -7429,15 +7513,15 @@ static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int } } -static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * u = 0; +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { + stbi_uc *u = 0; stbi__gif g; memset(&g, 0, sizeof(g)); STBI_NOTUSED(ri); u = stbi__gif_load_next(s, &g, comp, req_comp, 0); if (u == (stbi_uc *)s) - u = 0; // end of animated gif marker + u = 0; // end of animated gif marker if (u) { *x = g.w; *y = g.h; @@ -7458,14 +7542,16 @@ static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, in return u; } -static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp) { return stbi__gif_info_raw(s, x, y, comp); } +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) { + return stbi__gif_info_raw(s, x, y, comp); +} #endif // ************************************************************************************************* // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context * s, const char * signature) { +static int stbi__hdr_test_core(stbi__context *s, const char *signature) { int i; for (i = 0; signature[i]; ++i) if (stbi__get8(s) != signature[i]) @@ -7474,7 +7560,7 @@ static int stbi__hdr_test_core(stbi__context * s, const char * signature) { return 1; } -static int stbi__hdr_test(stbi__context * s) { +static int stbi__hdr_test(stbi__context *s) { int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); stbi__rewind(s); if (!r) { @@ -7485,7 +7571,7 @@ static int stbi__hdr_test(stbi__context * s) { } #define STBI__HDR_BUFLEN 1024 -static char * stbi__hdr_gettoken(stbi__context * z, char * buffer) { +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) { int len = 0; char c = '\0'; @@ -7506,7 +7592,7 @@ static char * stbi__hdr_gettoken(stbi__context * z, char * buffer) { return buffer; } -static void stbi__hdr_convert(float * output, stbi_uc * input, int req_comp) { +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) { if (input[3] != 0) { float f1; // Exponent @@ -7538,17 +7624,17 @@ static void stbi__hdr_convert(float * output, stbi_uc * input, int req_comp) { } } -static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { char buffer[STBI__HDR_BUFLEN]; - char * token; + char *token; int valid = 0; int width, height; - stbi_uc * scanline; - float * hdr_data; + stbi_uc *scanline; + float *hdr_data; int len; unsigned char count, value; int i, j, k, c1, c2, z; - const char * headerToken; + const char *headerToken; STBI_NOTUSED(ri); // Check identifier @@ -7635,7 +7721,7 @@ static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, i i = 1; j = 0; STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense + goto main_decode_loop; // yes, this makes no sense } len <<= 8; len |= stbi__get8(s); @@ -7690,9 +7776,9 @@ static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, i return hdr_data; } -static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) { char buffer[STBI__HDR_BUFLEN]; - char * token; + char *token; int valid = 0; int dummy; @@ -7738,11 +7824,11 @@ static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp) { *comp = 3; return 1; } -#endif // STBI_NO_HDR +#endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp) { - void * p; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) { + void *p; stbi__bmp_data info; info.all_a = 255; @@ -7766,7 +7852,7 @@ static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp) { #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) { int channelCount, dummy, depth; if (!x) x = &dummy; @@ -7803,7 +7889,7 @@ static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp) { return 1; } -static int stbi__psd_is16(stbi__context * s) { +static int stbi__psd_is16(stbi__context *s) { int channelCount, depth; if (stbi__get32be(s) != 0x38425053) { stbi__rewind(s); @@ -7831,7 +7917,7 @@ static int stbi__psd_is16(stbi__context * s) { #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) { int act_comp = 0, num_packets = 0, chained, dummy; stbi__pic_packet packets[10]; @@ -7863,7 +7949,7 @@ static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { stbi__skip(s, 8); do { - stbi__pic_packet * packet; + stbi__pic_packet *packet; if (num_packets == sizeof(packets) / sizeof(packets[0])) return 0; @@ -7904,7 +7990,7 @@ static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context * s) { +static int stbi__pnm_test(stbi__context *s) { char p, t; p = (char)stbi__get8(s); t = (char)stbi__get8(s); @@ -7915,8 +8001,8 @@ static int stbi__pnm_test(stbi__context * s) { return 1; } -static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * out; +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) { + stbi_uc *out; STBI_NOTUSED(ri); ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); @@ -7951,14 +8037,16 @@ static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, in out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); } if (out == NULL) - return out; // stbi__convert_format frees input on failure + return out; // stbi__convert_format frees input on failure } return out; } -static int stbi__pnm_isspace(char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; } +static int stbi__pnm_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +} -static void stbi__pnm_skip_whitespace(stbi__context * s, char * c) { +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) { for (;;) { while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) *c = (char)stbi__get8(s); @@ -7971,9 +8059,11 @@ static void stbi__pnm_skip_whitespace(stbi__context * s, char * c) { } } -static int stbi__pnm_isdigit(char c) { return c >= '0' && c <= '9'; } +static int stbi__pnm_isdigit(char c) { + return c >= '0' && c <= '9'; +} -static int stbi__pnm_getinteger(stbi__context * s, char * c) { +static int stbi__pnm_getinteger(stbi__context *s, char *c) { int value = 0; while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { @@ -7986,7 +8076,7 @@ static int stbi__pnm_getinteger(stbi__context * s, char * c) { return value; } -static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) { int maxv, dummy; char c, p, t; @@ -8007,22 +8097,22 @@ static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp) { return 0; } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm c = (char)stbi__get8(s); stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width + *x = stbi__pnm_getinteger(s, &c); // read width if (*x == 0) return stbi__err("invalid width", "PPM image header had zero or overflowing width"); stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height + *y = stbi__pnm_getinteger(s, &c); // read height if (*y == 0) return stbi__err("invalid width", "PPM image header had zero or overflowing width"); stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value + maxv = stbi__pnm_getinteger(s, &c); // read max value if (maxv > 65535) return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); else if (maxv > 255) @@ -8031,14 +8121,14 @@ static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp) { return 8; } -static int stbi__pnm_is16(stbi__context * s) { +static int stbi__pnm_is16(stbi__context *s) { if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) return 1; return 0; } #endif -static int stbi__info_main(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) { #ifndef STBI_NO_JPEG if (stbi__jpeg_info(s, x, y, comp)) return 1; @@ -8087,7 +8177,7 @@ static int stbi__info_main(stbi__context * s, int * x, int * y, int * comp) { return stbi__err("unknown image type", "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context * s) { +static int stbi__is_16_main(stbi__context *s) { #ifndef STBI_NO_PNG if (stbi__png_is16(s)) return 1; @@ -8106,8 +8196,8 @@ static int stbi__is_16_main(stbi__context * s) { } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) { + FILE *f = stbi__fopen(filename, "rb"); int result; if (!f) return stbi__err("can't fopen", "Unable to open file"); @@ -8116,7 +8206,7 @@ STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp) { return result; } -STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp) { +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) { int r; stbi__context s; long pos = ftell(f); @@ -8126,8 +8216,8 @@ STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp) { return r; } -STBIDEF int stbi_is_16_bit(char const * filename) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_is_16_bit(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); int result; if (!f) return stbi__err("can't fopen", "Unable to open file"); @@ -8136,7 +8226,7 @@ STBIDEF int stbi_is_16_bit(char const * filename) { return result; } -STBIDEF int stbi_is_16_bit_from_file(FILE * f) { +STBIDEF int stbi_is_16_bit_from_file(FILE *f) { int r; stbi__context s; long pos = ftell(f); @@ -8145,33 +8235,33 @@ STBIDEF int stbi_is_16_bit_from_file(FILE * f) { fseek(f, pos, SEEK_SET); return r; } -#endif // !STBI_NO_STDIO +#endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp) { +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) { stbi__context s; stbi__start_mem(&s, buffer, len); return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * c, void * user, int * x, int * y, int * comp) { +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) { stbi__context s; stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len) { +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) { stbi__context s; stbi__start_mem(&s, buffer, len); return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * user) { +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) { stbi__context s; stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); return stbi__is_16_main(&s); } -#endif // STB_IMAGE_IMPLEMENTATION +#endif // STB_IMAGE_IMPLEMENTATION /* revision history: diff --git a/src/bb/opencv/bb.h b/src/bb/opencv/bb.h index 9059c316..718d0ec0 100644 --- a/src/bb/opencv/bb.h +++ b/src/bb/opencv/bb.h @@ -58,4 +58,4 @@ class Display : public ion::BuildingBlock { }; ION_REGISTER_BUILDING_BLOCK(Display, opencv_display); -#endif // ION_BB_OPENCV_BB_H +#endif // ION_BB_OPENCV_BB_H diff --git a/src/bb/opencv/rt.h b/src/bb/opencv/rt.h index 18294b8e..8eab3714 100644 --- a/src/bb/opencv/rt.h +++ b/src/bb/opencv/rt.h @@ -16,28 +16,26 @@ namespace opencv { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; - -} // image_io -} // bb -} // ion +} // namespace opencv +} // namespace bb +} // namespace ion #define ION_REGISTER_EXTERN(NAME) static auto ion_register_extern_##NAME = ion::bb::opencv::RegisterExtern(#NAME, NAME); -extern "C" ION_EXPORT -int ion_bb_opencv_median_blur(halide_buffer_t *in, int ksize, halide_buffer_t *out) { - auto& cv(ion::bb::OpenCV::get_instance()); +extern "C" ION_EXPORT int ion_bb_opencv_median_blur(halide_buffer_t *in, int ksize, halide_buffer_t *out) { + auto &cv(ion::bb::OpenCV::get_instance()); if (!cv.is_available()) { ion::log::error("OpenCV is not available"); return -1; } if (in->is_bounds_query()) { - for (auto i=0; idimensions; ++i) { + for (auto i = 0; i < in->dimensions; ++i) { in->dim[i].min = out->dim[i].min; in->dim[i].extent = out->dim[i].extent; } @@ -50,10 +48,10 @@ int ion_bb_opencv_median_blur(halide_buffer_t *in, int ksize, halide_buffer_t *o } auto src = cv.cvCreateMatHeader(height, width, cv_type); - cv.cvSetData(src, in->host, 3*width*sizeof(uint8_t)); + cv.cvSetData(src, in->host, 3 * width * sizeof(uint8_t)); auto dst = cv.cvCreateMatHeader(height, width, cv_type); - cv.cvSetData(dst, out->host, 3*width*sizeof(uint8_t)); + cv.cvSetData(dst, out->host, 3 * width * sizeof(uint8_t)); cv.cvSmooth(src, dst, CV_MEDIAN, ksize, ksize, 0, 0); @@ -65,9 +63,8 @@ int ion_bb_opencv_median_blur(halide_buffer_t *in, int ksize, halide_buffer_t *o } ION_REGISTER_EXTERN(ion_bb_opencv_median_blur); -extern "C" ION_EXPORT -int ion_bb_opencv_display(halide_buffer_t *in, int width, int height, int idx, halide_buffer_t *out) { - auto& cv(ion::bb::OpenCV::get_instance()); +extern "C" ION_EXPORT int ion_bb_opencv_display(halide_buffer_t *in, int width, int height, int idx, halide_buffer_t *out) { + auto &cv(ion::bb::OpenCV::get_instance()); if (!cv.is_available()) { ion::log::error("OpenCV is not available"); return -1; @@ -75,14 +72,14 @@ int ion_bb_opencv_display(halide_buffer_t *in, int width, int height, int idx, h if (in->is_bounds_query()) { in->dim[0].min = 0; - in->dim[0].extent = 3; // RGB + in->dim[0].extent = 3; // RGB in->dim[1].min = 0; in->dim[1].extent = width; in->dim[2].min = 0; in->dim[2].extent = height; } else { auto img = cv.cvCreateMatHeader(height, width, CV_MAKETYPE(CV_8U, 3)); - cv.cvSetData(img, in->host, 3*width*sizeof(uint8_t)); + cv.cvSetData(img, in->host, 3 * width * sizeof(uint8_t)); auto name = "img" + std::to_string(idx); cv.cvShowImage(name.c_str(), img); @@ -97,4 +94,4 @@ ION_REGISTER_EXTERN(ion_bb_opencv_display); #undef ION_REGISTER_EXTERN -#endif // ION_BB_OPENCV_RT_H +#endif // ION_BB_OPENCV_RT_H diff --git a/src/bb/opencv_loader.h b/src/bb/opencv_loader.h index a5cfb307..9030c1a0 100644 --- a/src/bb/opencv_loader.h +++ b/src/bb/opencv_loader.h @@ -5,49 +5,47 @@ namespace { -#define CV_CN_MAX 512 -#define CV_CN_SHIFT 3 -#define CV_DEPTH_MAX (1 << CV_CN_SHIFT) - -#define CV_8U 0 -#define CV_8S 1 -#define CV_16U 2 -#define CV_16S 3 -#define CV_32S 4 -#define CV_32F 5 -#define CV_64F 6 -#define CV_16F 7 - -#define CV_MAT_DEPTH_MASK (CV_DEPTH_MAX - 1) -#define CV_MAT_DEPTH(flags) ((flags) & CV_MAT_DEPTH_MASK) - -#define CV_MAKETYPE(depth,cn) (CV_MAT_DEPTH(depth) + (((cn)-1) << CV_CN_SHIFT)) +#define CV_CN_MAX 512 +#define CV_CN_SHIFT 3 +#define CV_DEPTH_MAX (1 << CV_CN_SHIFT) + +#define CV_8U 0 +#define CV_8S 1 +#define CV_16U 2 +#define CV_16S 3 +#define CV_32S 4 +#define CV_32F 5 +#define CV_64F 6 +#define CV_16F 7 + +#define CV_MAT_DEPTH_MASK (CV_DEPTH_MAX - 1) +#define CV_MAT_DEPTH(flags) ((flags)&CV_MAT_DEPTH_MASK) + +#define CV_MAKETYPE(depth, cn) (CV_MAT_DEPTH(depth) + (((cn)-1) << CV_CN_SHIFT)) #define CV_MAKE_TYPE CV_MAKETYPE #define IMREAD_GRAYSCALE 0 #define IMREAD_COLOR 1 - -enum SmoothMethod_c -{ +enum SmoothMethod_c { /** linear convolution with \f$\texttt{size1}\times\texttt{size2}\f$ box kernel (all 1's). If you want to smooth different pixels with different-size box kernels, you can use the integral image that is computed using integral */ - CV_BLUR_NO_SCALE =0, + CV_BLUR_NO_SCALE = 0, /** linear convolution with \f$\texttt{size1}\times\texttt{size2}\f$ box kernel (all 1's) with subsequent scaling by \f$1/(\texttt{size1}\cdot\texttt{size2})\f$ */ - CV_BLUR =1, + CV_BLUR = 1, /** linear convolution with a \f$\texttt{size1}\times\texttt{size2}\f$ Gaussian kernel */ - CV_GAUSSIAN =2, + CV_GAUSSIAN = 2, /** median filter with a \f$\texttt{size1}\times\texttt{size1}\f$ square aperture */ - CV_MEDIAN =3, + CV_MEDIAN = 3, /** bilateral filter with a \f$\texttt{size1}\times\texttt{size1}\f$ square aperture, color sigma= sigma1 and spatial sigma= sigma2. If size1=0, the aperture square side is set to cvRound(sigma2\*1.5)\*2+1. See cv::bilateralFilter */ - CV_BILATERAL =4 + CV_BILATERAL = 4 }; -} // anonymous +} // namespace namespace ion { namespace bb { @@ -57,54 +55,53 @@ class OpenCV { using CvArr = void; using CvMat = struct CvMat; // To fo, remove c api function later, since it is deprecated - using cvCreateMatHeader_t = CvMat*(*)(int rows, int cols, int type); - using cvReleaseMat_t = void(*)(CvMat **mat); - using cvSetData_t = void(*)(CvArr* arr, void* data, int step); - using cvSplit_t = void(*)(const CvArr *src, CvArr *dst0, CvArr *dst1, CvArr *dst2, CvArr *dst3); - using cvPow_t = void(*) (const CvArr *src, CvArr *dst, double power); - using cvRepeat_t = void(*) (const CvArr *src, CvArr *dst); - using cvSmooth_t = void (*)(const CvArr* src, CvArr* dst, int smoothtype, int size1, int size2, double sigma1, double sigma2); - using cvShowImage_t = void(*)(const char* name, const CvArr* image); - using cvCvtColor_t = void(*)(const CvArr *src, CvArr *dst, int code); -// using cvSaveImage_t = void(*)(const char *filename, const CvArr *image, const int *params); -// using cvLoadImageM_t =CvMat*(*)(const char* filename, int iscolor); - using cvWaitKey_t = int(*)(int delay); + using cvCreateMatHeader_t = CvMat *(*)(int rows, int cols, int type); + using cvReleaseMat_t = void (*)(CvMat **mat); + using cvSetData_t = void (*)(CvArr *arr, void *data, int step); + using cvSplit_t = void (*)(const CvArr *src, CvArr *dst0, CvArr *dst1, CvArr *dst2, CvArr *dst3); + using cvPow_t = void (*)(const CvArr *src, CvArr *dst, double power); + using cvRepeat_t = void (*)(const CvArr *src, CvArr *dst); + using cvSmooth_t = void (*)(const CvArr *src, CvArr *dst, int smoothtype, int size1, int size2, double sigma1, double sigma2); + using cvShowImage_t = void (*)(const char *name, const CvArr *image); + using cvCvtColor_t = void (*)(const CvArr *src, CvArr *dst, int code); + // using cvSaveImage_t = void(*)(const char *filename, const CvArr *image, const int *params); + // using cvLoadImageM_t =CvMat*(*)(const char* filename, int iscolor); + using cvWaitKey_t = int (*)(int delay); using cvResize_t = void (*)(const CvArr *src, CvArr *dst, int interpolation); - using cvNormalize_t = void (*)(const CvArr* src, CvArr* dst, double alpha, double beta, int normtype, int dtype, const CvArr* mask ); - - public: + using cvNormalize_t = void (*)(const CvArr *src, CvArr *dst, double alpha, double beta, int normtype, int dtype, const CvArr *mask); +public: static OpenCV &get_instance() { static OpenCV instance; return instance; } - OpenCV() : + OpenCV() + : #ifdef _WIN32 - // TODO: Determine OpenCV version dynamically - opencv_world_("opencv_world455", false) + // TODO: Determine OpenCV version dynamically + opencv_world_("opencv_world455", false) #else - opencv_core_("opencv_core", false), - opencv_imgproc_("opencv_imgproc", false), - opencv_highgui_("opencv_highgui", false) + opencv_core_("opencv_core", false), + opencv_imgproc_("opencv_imgproc", false), + opencv_highgui_("opencv_highgui", false) #endif { if (is_available()) { init_symbols(); } else { - } } void init_symbols() { #ifdef _WIN32 - #define GET_SYMBOL(LOCAL_VAR, TARGET_SYMBOL) \ - LOCAL_VAR = opencv_world_.get_symbol(TARGET_SYMBOL); \ - if (LOCAL_VAR == nullptr) { \ - throw ::std::runtime_error( \ - TARGET_SYMBOL " is unavailable on opencv_world"); \ - } +#define GET_SYMBOL(LOCAL_VAR, TARGET_SYMBOL) \ + LOCAL_VAR = opencv_world_.get_symbol(TARGET_SYMBOL); \ + if (LOCAL_VAR == nullptr) { \ + throw ::std::runtime_error( \ + TARGET_SYMBOL " is unavailable on opencv_world"); \ + } GET_SYMBOL(cvCreateMatHeader, "cvCreateMatHeader"); GET_SYMBOL(cvReleaseMat, "cvReleaseMat"); @@ -118,14 +115,14 @@ class OpenCV { GET_SYMBOL(cvShowImage, "cvShowImage"); GET_SYMBOL(cvWaitKey, "cvWaitKey"); - #undef GET_SYMBOL +#undef GET_SYMBOL #else - #define GET_SYMBOL(MODULE, LOCAL_VAR, TARGET_SYMBOL) \ - LOCAL_VAR = opencv_##MODULE##_.get_symbol(TARGET_SYMBOL); \ - if (LOCAL_VAR == nullptr) { \ - throw ::std::runtime_error( \ - TARGET_SYMBOL " is unavailable on " #MODULE); \ - } +#define GET_SYMBOL(MODULE, LOCAL_VAR, TARGET_SYMBOL) \ + LOCAL_VAR = opencv_##MODULE##_.get_symbol(TARGET_SYMBOL); \ + if (LOCAL_VAR == nullptr) { \ + throw ::std::runtime_error( \ + TARGET_SYMBOL " is unavailable on " #MODULE); \ + } GET_SYMBOL(core, cvCreateMatHeader, "cvCreateMatHeader"); GET_SYMBOL(core, cvReleaseMat, "cvReleaseMat"); @@ -135,11 +132,11 @@ class OpenCV { GET_SYMBOL(core, cvRepeat, "cvRepeat"); GET_SYMBOL(imgproc, cvSmooth, "cvSmooth"); GET_SYMBOL(imgproc, cvResize, "cvResize"); - GET_SYMBOL(imgproc, cvNormalize, "cvNormalize"); //obsolete + GET_SYMBOL(imgproc, cvNormalize, "cvNormalize"); // obsolete GET_SYMBOL(imgproc, cvCvtColor, "cvCvtColor"); GET_SYMBOL(highgui, cvShowImage, "cvShowImage"); GET_SYMBOL(highgui, cvWaitKey, "cvWaitKey"); - #undef GET_SYMBOL +#undef GET_SYMBOL #endif } @@ -166,17 +163,14 @@ class OpenCV { cvReleaseMat_t cvReleaseMat; cvSetData_t cvSetData; cvSplit_t cvSplit; - cvPow_t cvPow; + cvPow_t cvPow; cvRepeat_t cvRepeat; cvSmooth_t cvSmooth; - cvResize_t cvResize; + cvResize_t cvResize; cvNormalize_t cvNormalize; cvShowImage_t cvShowImage; cvWaitKey_t cvWaitKey; cvCvtColor_t cvCvtColor; - - - }; int hl2cv_type(halide_type_t hl_type, int channel) { @@ -192,7 +186,7 @@ int hl2cv_type(halide_type_t hl_type, int channel) { } } -} // bb -} // ion +} // namespace bb +} // namespace ion #endif diff --git a/src/bb/sgm/rt.h b/src/bb/sgm/rt.h index 890c71b6..e791676d 100644 --- a/src/bb/sgm/rt.h +++ b/src/bb/sgm/rt.h @@ -10,14 +10,14 @@ namespace sgm { std::map extern_functions; class RegisterExtern { - public: - RegisterExtern(std::string key, Halide::ExternCFunction f) { - extern_functions[key] = f; - } +public: + RegisterExtern(std::string key, Halide::ExternCFunction f) { + extern_functions[key] = f; + } }; -} // image_io -} // bb -} // ion +} // namespace sgm +} // namespace bb +} // namespace ion #endif diff --git a/src/builder.cc b/src/builder.cc index d96eabf9..3a6c5306 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -33,14 +33,18 @@ std::map compute_output_files(const Halide: std::string to_string(Halide::Argument::Kind kind) { switch (kind) { - case Halide::Argument::Kind::InputScalar: return "InputScalar"; - case Halide::Argument::Kind::InputBuffer: return "InputBuffer"; - case Halide::Argument::Kind::OutputBuffer: return "OutputBuffer"; - default: return "Unknown"; + case Halide::Argument::Kind::InputScalar: + return "InputScalar"; + case Halide::Argument::Kind::InputBuffer: + return "InputBuffer"; + case Halide::Argument::Kind::OutputBuffer: + return "OutputBuffer"; + default: + return "Unknown"; } } -} // anonymous +} // namespace using json = nlohmann::json; @@ -51,78 +55,73 @@ struct Builder::Impl { std::map jit_externs; std::vector graphs; std::vector nodes; - std::vector>> disposers; + std::vector>> disposers; // Cacheable Halide::Pipeline pipeline; Halide::Callable callable; std::unique_ptr jit_ctx; - Halide::JITUserContext* jit_ctx_ptr; - std::vector args; + Halide::JITUserContext *jit_ctx_ptr; + std::vector args; - Impl() : jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get()) { + Impl() + : jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get()) { } ~Impl(); }; Builder::Builder() - : impl_(new Impl) -{ + : impl_(new Impl) { } -Builder::~Builder() -{ +Builder::~Builder() { } -Builder::Impl::~Impl() -{ +Builder::Impl::~Impl() { for (auto [bb_id, disposer] : disposers) { disposer(bb_id.c_str()); } } -Node Builder::add(const std::string& name) -{ +Node Builder::add(const std::string &name) { Node n(sole::uuid4().str(), name, impl_->target); impl_->nodes.push_back(n); return n; } -Node Builder::add(const std::string& name, const GraphID & graph_id) -{ +Node Builder::add(const std::string &name, const GraphID &graph_id) { Node n(sole::uuid4().str(), name, impl_->target, graph_id); impl_->nodes.push_back(n); return n; } -Graph Builder::add_graph(const std::string& name) { +Graph Builder::add_graph(const std::string &name) { Graph g(*this, name); impl_->graphs.push_back(g); return g; } -Builder& Builder::set_target(const Halide::Target& target) { +Builder &Builder::set_target(const Halide::Target &target) { impl_->target = target; return *this; } -Builder& Builder::set_jit_context(Halide::JITUserContext *user_context_ptr) { +Builder &Builder::set_jit_context(Halide::JITUserContext *user_context_ptr) { impl_->jit_ctx_ptr = user_context_ptr; return *this; } -Builder& Builder::with_bb_module(const std::string& module_name_or_path) { +Builder &Builder::with_bb_module(const std::string &module_name_or_path) { auto bb_module = std::make_shared(module_name_or_path); - auto register_extern = bb_module->get_symbol&)>("register_externs"); + auto register_extern = bb_module->get_symbol &)>("register_externs"); if (register_extern) { register_extern(impl_->jit_externs); - } impl_->bb_modules[module_name_or_path] = bb_module; return *this; } -void Builder::save(const std::string& file_name) { +void Builder::save(const std::string &file_name) { determine_and_validate(impl_->nodes); std::ofstream ofs(file_name); json j; @@ -132,7 +131,7 @@ void Builder::save(const std::string& file_name) { return; } -void Builder::load(const std::string& file_name) { +void Builder::load(const std::string &file_name) { std::ifstream ifs(file_name); json j; ifs >> j; @@ -141,7 +140,7 @@ void Builder::load(const std::string& file_name) { return; } -void Builder::compile(const std::string& function_name, const CompileOption& option) { +void Builder::compile(const std::string &function_name, const CompileOption &option) { using namespace Halide; // Build pipeline and module first @@ -193,7 +192,7 @@ void Builder::compile(const std::string& function_name, const CompileOption& opt } void Builder::run() { - if (!impl_->pipeline.defined()) { + if (!impl_->pipeline.defined()) { impl_->pipeline = lower(*this, impl_->nodes, false); if (!impl_->pipeline.defined()) { log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port."); @@ -204,10 +203,9 @@ void Builder::run() { if (!impl_->callable.defined()) { std::map jit_externs; for (auto bb : impl_->bb_modules) { - auto register_extern = bb.second->get_symbol&)>("register_externs"); + auto register_extern = bb.second->get_symbol &)>("register_externs"); if (register_extern) { register_extern(jit_externs); - } } impl_->pipeline.set_jit_externs(jit_externs); @@ -220,7 +218,7 @@ void Builder::run() { impl_->args.clear(); impl_->args.push_back(&impl_->jit_ctx_ptr); - const auto& args(generate_arguments_instance(inferred_args, impl_->nodes)); + const auto &args(generate_arguments_instance(inferred_args, impl_->nodes)); impl_->args.insert(impl_->args.end(), args.begin(), args.end()); } @@ -230,12 +228,12 @@ void Builder::run() { std::vector Builder::bb_names(void) { std::vector names; for (auto n : Halide::Internal::GeneratorRegistry::enumerate()) { - names.push_back(n); + names.push_back(n); } return names; } -std::vector Builder::bb_arginfos(const std::string& name) { +std::vector Builder::bb_arginfos(const std::string &name) { auto generator_names = Halide::Internal::GeneratorRegistry::enumerate(); if (std::find(generator_names.begin(), generator_names.end(), name) == generator_names.end()) { @@ -257,7 +255,7 @@ std::vector Builder::bb_arginfos(const std::string& name) { try { bb->build_pipeline(); - } catch (const Halide::CompileError& e) { + } catch (const Halide::CompileError &e) { log::error(e.what()); throw std::runtime_error(e.what()); } @@ -269,7 +267,7 @@ std::string Builder::bb_metadata(void) { std::vector md; for (auto n : Halide::Internal::GeneratorRegistry::enumerate()) { - md.push_back(Metadata(n)); + md.push_back(Metadata(n)); } json j(md); @@ -281,15 +279,15 @@ Target Builder::target() const { return impl_->target; } -const std::vector& Builder::nodes() const { +const std::vector &Builder::nodes() const { return impl_->nodes; } -std::vector& Builder::nodes() { +std::vector &Builder::nodes() { return impl_->nodes; } -const std::map& Builder::jit_externs() const { +const std::map &Builder::jit_externs() const { return impl_->jit_externs; } @@ -299,21 +297,19 @@ void Builder::print_loop_nest() { } } -void Builder::register_disposer(Impl *impl, const std::string& bb_id, const std::string& disposer_symbol) { +void Builder::register_disposer(Impl *impl, const std::string &bb_id, const std::string &disposer_symbol) { log::info("Builder::register_disposer"); - for (const auto& kv : impl->bb_modules) { - const auto& dm(kv.second); - auto disposer_ptr = dm->get_symbol(disposer_symbol); + for (const auto &kv : impl->bb_modules) { + const auto &dm(kv.second); + auto disposer_ptr = dm->get_symbol(disposer_symbol); if (disposer_ptr) { impl->disposers.push_back(std::make_tuple(bb_id, disposer_ptr)); } } } - -const Builder::Impl* Builder::impl_ptr() const { +const Builder::Impl *Builder::impl_ptr() const { return impl_.get(); } - -} //namespace ion +} // namespace ion diff --git a/src/c_ion.cc b/src/c_ion.cc index e98287ed..fcf9abe7 100644 --- a/src/c_ion.cc +++ b/src/c_ion.cc @@ -12,24 +12,23 @@ namespace { template std::vector> convert(ion_buffer_t *b, int n) { std::vector> bs(n); - for (int i=0; i*>(b[i]); + for (int i = 0; i < n; ++i) { + bs[i] = *reinterpret_cast *>(b[i]); } return bs; } -} +} // namespace // // ion_port_t // -int ion_port_create(ion_port_t *ptr, const char *key, ion_type_t type, int dim) -{ +int ion_port_create(ion_port_t *ptr, const char *key, ion_type_t type, int dim) { try { *ptr = reinterpret_cast(new Port(key, halide_type_t(static_cast(type.code), type.bits, type.lanes), dim)); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -40,16 +39,15 @@ int ion_port_create(ion_port_t *ptr, const char *key, ion_type_t type, int dim) return 0; } -int ion_port_create_with_index(ion_port_t *ptr, ion_port_t obj, int index) -{ +int ion_port_create_with_index(ion_port_t *ptr, ion_port_t obj, int index) { try { - auto p = new Port(*reinterpret_cast(obj)); + auto p = new Port(*reinterpret_cast(obj)); p->set_index(index); *ptr = reinterpret_cast(p); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -60,14 +58,13 @@ int ion_port_create_with_index(ion_port_t *ptr, ion_port_t obj, int index) return 0; } -int ion_port_destroy(ion_port_t obj) -{ +int ion_port_destroy(ion_port_t obj) { try { - delete reinterpret_cast(obj); - } catch (const Halide::Error& e) { + delete reinterpret_cast(obj); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -81,11 +78,11 @@ int ion_port_destroy(ion_port_t obj) #define ION_PORT_BIND_IMPL(T, POSTFIX) \ int ion_port_bind_##POSTFIX(ion_port_t obj, T v) { \ try { \ - reinterpret_cast(obj)->bind(v); \ - } catch (const Halide::Error& e) { \ + reinterpret_cast(obj)->bind(v); \ + } catch (const Halide::Error &e) { \ log::error(e.what()); \ return 1; \ - } catch (const std::exception& e) { \ + } catch (const std::exception &e) { \ log::error(e.what()); \ return 1; \ } catch (...) { \ @@ -96,66 +93,65 @@ int ion_port_destroy(ion_port_t obj) return 0; \ } -ION_PORT_BIND_IMPL(int8_t*, i8) -ION_PORT_BIND_IMPL(int16_t*, i16) -ION_PORT_BIND_IMPL(int32_t*, i32) -ION_PORT_BIND_IMPL(int64_t*, i64) -ION_PORT_BIND_IMPL(bool*, u1) -ION_PORT_BIND_IMPL(uint8_t*, u8) -ION_PORT_BIND_IMPL(uint16_t*, u16) -ION_PORT_BIND_IMPL(uint32_t*, u32) -ION_PORT_BIND_IMPL(uint64_t*, u64) -ION_PORT_BIND_IMPL(float*, f32) -ION_PORT_BIND_IMPL(double*, f64) +ION_PORT_BIND_IMPL(int8_t *, i8) +ION_PORT_BIND_IMPL(int16_t *, i16) +ION_PORT_BIND_IMPL(int32_t *, i32) +ION_PORT_BIND_IMPL(int64_t *, i64) +ION_PORT_BIND_IMPL(bool *, u1) +ION_PORT_BIND_IMPL(uint8_t *, u8) +ION_PORT_BIND_IMPL(uint16_t *, u16) +ION_PORT_BIND_IMPL(uint32_t *, u32) +ION_PORT_BIND_IMPL(uint64_t *, u64) +ION_PORT_BIND_IMPL(float *, f32) +ION_PORT_BIND_IMPL(double *, f64) #undef ION_PORT_BIND_IMPL -int ion_port_bind_buffer(ion_port_t obj, ion_buffer_t b) -{ +int ion_port_bind_buffer(ion_port_t obj, ion_buffer_t b) { try { // NOTE: Halide::Buffer class layout is safe to call Halide::Buffer::type() - auto type = reinterpret_cast*>(b)->type(); + auto type = reinterpret_cast *>(b)->type(); if (type.is_int()) { if (type.bits() == 8) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 16) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 32) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 64) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_uint()) { if (type.bits() == 1) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 8) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 16) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 32) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 64) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_float()) { if (type.bits() == 32) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else if (type.bits() == 64) { - reinterpret_cast(obj)->bind(*reinterpret_cast*>(b)); + reinterpret_cast(obj)->bind(*reinterpret_cast *>(b)); } else { throw std::runtime_error("Unsupported bits number"); } } else { throw std::runtime_error("Unsupported type code"); } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -163,56 +159,54 @@ int ion_port_bind_buffer(ion_port_t obj, ion_buffer_t b) return 1; } - return 0; } -int ion_port_bind_buffer_array(ion_port_t obj, ion_buffer_t *bs, int n) -{ +int ion_port_bind_buffer_array(ion_port_t obj, ion_buffer_t *bs, int n) { try { // NOTE: Halide::Buffer class layout is safe to call Halide::Buffer::type() - auto type = reinterpret_cast*>(*bs)->type(); + auto type = reinterpret_cast *>(*bs)->type(); if (type.is_int()) { if (type.bits() == 8) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 16) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 32) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 64) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_uint()) { if (type.bits() == 1) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 8) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 16) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 32) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 64) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_float()) { if (type.bits() == 32) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else if (type.bits() == 64) { - reinterpret_cast(obj)->bind(convert(bs, n)); + reinterpret_cast(obj)->bind(convert(bs, n)); } else { throw std::runtime_error("Unsupported bits number"); } } else { throw std::runtime_error("Unsupported type code"); } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -220,20 +214,18 @@ int ion_port_bind_buffer_array(ion_port_t obj, ion_buffer_t *bs, int n) return 1; } - return 0; } // // ion_param_t // -int ion_param_create(ion_param_t *ptr, const char *key, const char *value) -{ +int ion_param_create(ion_param_t *ptr, const char *key, const char *value) { try { *ptr = reinterpret_cast(new Param(key, value)); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -244,14 +236,13 @@ int ion_param_create(ion_param_t *ptr, const char *key, const char *value) return 0; } -int ion_param_destroy(ion_param_t obj) -{ +int ion_param_destroy(ion_param_t obj) { try { - delete reinterpret_cast(obj); - } catch (const Halide::Error& e) { + delete reinterpret_cast(obj); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -265,14 +256,13 @@ int ion_param_destroy(ion_param_t obj) // // ion_node_t // -int ion_node_create(ion_node_t *ptr) -{ +int ion_node_create(ion_node_t *ptr) { try { *ptr = reinterpret_cast(new Node); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -283,14 +273,13 @@ int ion_node_create(ion_node_t *ptr) return 0; } -int ion_node_destroy(ion_node_t obj) -{ +int ion_node_destroy(ion_node_t obj) { try { - delete reinterpret_cast(obj); - } catch (const Halide::Error& e) { + delete reinterpret_cast(obj); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -299,17 +288,15 @@ int ion_node_destroy(ion_node_t obj) } return 0; - } -int ion_node_get_port(ion_node_t obj, const char *key, ion_port_t *port_ptr) -{ +int ion_node_get_port(ion_node_t obj, const char *key, ion_port_t *port_ptr) { try { - *port_ptr = reinterpret_cast(new Port((*reinterpret_cast(obj))[key])); - } catch (const Halide::Error& e) { + *port_ptr = reinterpret_cast(new Port((*reinterpret_cast(obj))[key])); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -320,18 +307,17 @@ int ion_node_get_port(ion_node_t obj, const char *key, ion_port_t *port_ptr) return 0; } -int ion_node_set_iports(ion_node_t obj, ion_port_t *ports_ptr, int ports_num) -{ +int ion_node_set_iports(ion_node_t obj, ion_port_t *ports_ptr, int ports_num) { try { std::vector ports(ports_num); - for (int i=0; i(ports_ptr[i]); + for (int i = 0; i < ports_num; ++i) { + ports[i] = *reinterpret_cast(ports_ptr[i]); } - reinterpret_cast(obj)->set_iports(ports); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->set_iports(ports); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -342,18 +328,17 @@ int ion_node_set_iports(ion_node_t obj, ion_port_t *ports_ptr, int ports_num) return 0; } -int ion_node_set_params(ion_node_t obj, ion_param_t *params_ptr, int params_num) -{ +int ion_node_set_params(ion_node_t obj, ion_param_t *params_ptr, int params_num) { try { std::vector params(params_num); - for (int i=0; i(params_ptr[i]); + for (int i = 0; i < params_num; ++i) { + params[i] = *reinterpret_cast(params_ptr[i]); } - reinterpret_cast(obj)->set_params(params); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->set_params(params); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -367,14 +352,13 @@ int ion_node_set_params(ion_node_t obj, ion_param_t *params_ptr, int params_num) // // ion_builder_t // -int ion_builder_create(ion_builder_t *ptr) -{ +int ion_builder_create(ion_builder_t *ptr) { try { *ptr = reinterpret_cast(new Builder); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -385,14 +369,13 @@ int ion_builder_create(ion_builder_t *ptr) return 0; } -int ion_builder_destroy(ion_builder_t obj) -{ +int ion_builder_destroy(ion_builder_t obj) { try { - delete reinterpret_cast(obj); - } catch (const Halide::Error& e) { + delete reinterpret_cast(obj); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -403,14 +386,13 @@ int ion_builder_destroy(ion_builder_t obj) return 0; } -int ion_builder_set_target(ion_builder_t obj, const char *target) -{ +int ion_builder_set_target(ion_builder_t obj, const char *target) { try { reinterpret_cast(obj)->set_target(Halide::Target(target)); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -421,14 +403,13 @@ int ion_builder_set_target(ion_builder_t obj, const char *target) return 0; } -int ion_builder_with_bb_module(ion_builder_t obj, const char *module_name) -{ +int ion_builder_with_bb_module(ion_builder_t obj, const char *module_name) { try { reinterpret_cast(obj)->with_bb_module(module_name); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -438,14 +419,13 @@ int ion_builder_with_bb_module(ion_builder_t obj, const char *module_name) return 0; } -int ion_builder_add_graph(ion_builder_t obj, const char *name, ion_graph_t *graph_ptr) -{ +int ion_builder_add_graph(ion_builder_t obj, const char *name, ion_graph_t *graph_ptr) { try { - *graph_ptr = reinterpret_cast(new Graph(reinterpret_cast(obj)->add_graph(name))); - } catch (const Halide::Error& e) { + *graph_ptr = reinterpret_cast(new Graph(reinterpret_cast(obj)->add_graph(name))); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -456,14 +436,13 @@ int ion_builder_add_graph(ion_builder_t obj, const char *name, ion_graph_t *grap return 0; } -int ion_builder_add_node(ion_builder_t obj, const char *key, ion_node_t *node_ptr) -{ +int ion_builder_add_node(ion_builder_t obj, const char *key, ion_node_t *node_ptr) { try { - *node_ptr = reinterpret_cast(new Node(reinterpret_cast(obj)->add(key))); - } catch (const Halide::Error& e) { + *node_ptr = reinterpret_cast(new Node(reinterpret_cast(obj)->add(key))); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -474,14 +453,13 @@ int ion_builder_add_node(ion_builder_t obj, const char *key, ion_node_t *node_pt return 0; } -int ion_builder_compile(ion_builder_t obj, const char *function_name, ion_builder_compile_option_t option) -{ +int ion_builder_compile(ion_builder_t obj, const char *function_name, ion_builder_compile_option_t option) { try { - reinterpret_cast(obj)->compile(function_name, Builder::CompileOption{option.output_directory}); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->compile(function_name, Builder::CompileOption{option.output_directory}); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -491,14 +469,13 @@ int ion_builder_compile(ion_builder_t obj, const char *function_name, ion_builde return 0; } -int ion_builder_load(ion_builder_t obj, const char *file_name) -{ +int ion_builder_load(ion_builder_t obj, const char *file_name) { try { - reinterpret_cast(obj)->load(file_name); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->load(file_name); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -509,14 +486,13 @@ int ion_builder_load(ion_builder_t obj, const char *file_name) return 0; } -int ion_builder_save(ion_builder_t obj, const char *file_name) -{ +int ion_builder_save(ion_builder_t obj, const char *file_name) { try { - reinterpret_cast(obj)->save(file_name); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->save(file_name); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -527,10 +503,9 @@ int ion_builder_save(ion_builder_t obj, const char *file_name) return 0; } -int ion_builder_bb_metadata(ion_builder_t obj, char *ptr, int n, int *ret_n) -{ +int ion_builder_bb_metadata(ion_builder_t obj, char *ptr, int n, int *ret_n) { try { - auto md = reinterpret_cast(obj)->bb_metadata(); + auto md = reinterpret_cast(obj)->bb_metadata(); if (ptr != nullptr) { auto copy_size = (std::min)(static_cast(n), md.size()); std::memcpy(ptr, md.c_str(), copy_size); @@ -539,7 +514,7 @@ int ion_builder_bb_metadata(ion_builder_t obj, char *ptr, int n, int *ret_n) *ret_n = static_cast(md.size()); } } - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return -1; } catch (...) { @@ -550,14 +525,13 @@ int ion_builder_bb_metadata(ion_builder_t obj, char *ptr, int n, int *ret_n) return 0; } -int ion_builder_run(ion_builder_t obj) -{ +int ion_builder_run(ion_builder_t obj) { try { - reinterpret_cast(obj)->run(); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->run(); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -569,7 +543,7 @@ int ion_builder_run(ion_builder_t obj) } template -Halide::Buffer *make_buffer(const std::vector& sizes) { +Halide::Buffer *make_buffer(const std::vector &sizes) { if (sizes.empty()) { auto p = new Halide::Buffer(); *p = Halide::Buffer::make_scalar(); @@ -580,18 +554,17 @@ Halide::Buffer *make_buffer(const std::vector& sizes) { } template -Halide::Buffer *make_buffer(void *data, const std::vector& sizes) { +Halide::Buffer *make_buffer(void *data, const std::vector &sizes) { if (sizes.empty()) { auto p = new Halide::Buffer(); - *p = Halide::Buffer::make_scalar(reinterpret_cast(data)); + *p = Halide::Buffer::make_scalar(reinterpret_cast(data)); return p; } else { return new Halide::Buffer(reinterpret_cast(data), sizes); } } -int ion_buffer_create(ion_buffer_t *ptr, ion_type_t type, int *sizes_, int dim) -{ +int ion_buffer_create(ion_buffer_t *ptr, ion_type_t type, int *sizes_, int dim) { try { std::vector sizes(dim); std::memcpy(sizes.data(), sizes_, dim * sizeof(int)); @@ -635,10 +608,10 @@ int ion_buffer_create(ion_buffer_t *ptr, ion_type_t type, int *sizes_, int dim) } else { throw std::runtime_error("Unsupported type code"); } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -649,8 +622,7 @@ int ion_buffer_create(ion_buffer_t *ptr, ion_type_t type, int *sizes_, int dim) return 0; } -int ion_buffer_create_with_data(ion_buffer_t *ptr, ion_type_t type, void *data, int *sizes_, int dim) -{ +int ion_buffer_create_with_data(ion_buffer_t *ptr, ion_type_t type, void *data, int *sizes_, int dim) { try { std::vector sizes(dim); std::memcpy(sizes.data(), sizes_, dim * sizeof(int)); @@ -694,10 +666,10 @@ int ion_buffer_create_with_data(ion_buffer_t *ptr, ion_type_t type, void *data, } else { throw std::runtime_error("Unsupported type code"); } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -708,16 +680,14 @@ int ion_buffer_create_with_data(ion_buffer_t *ptr, ion_type_t type, void *data, return 0; } - -int ion_buffer_destroy(ion_buffer_t obj) -{ +int ion_buffer_destroy(ion_buffer_t obj) { try { // NOTE: Halide::Buffer class layout is safe to be deleted as T=void - delete reinterpret_cast*>(obj); - } catch (const Halide::Error& e) { + delete reinterpret_cast *>(obj); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -728,52 +698,51 @@ int ion_buffer_destroy(ion_buffer_t obj) return 0; } -int ion_buffer_write(ion_buffer_t obj, void *ptr, int size) -{ +int ion_buffer_write(ion_buffer_t obj, void *ptr, int size) { try { // NOTE: Halide::Buffer class layout is safe to call Halide::Buffer::type() - auto type = reinterpret_cast*>(obj)->type(); + auto type = reinterpret_cast *>(obj)->type(); if (type.is_int()) { if (type.bits() == 8) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 16) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 32) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 64) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_uint()) { if (type.bits() == 1) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 8) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 16) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 32) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 64) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_float()) { if (type.bits() == 32) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else if (type.bits() == 64) { - std::memcpy(reinterpret_cast*>(obj)->data(), ptr, size); + std::memcpy(reinterpret_cast *>(obj)->data(), ptr, size); } else { throw std::runtime_error("Unsupported bits number"); } } else { throw std::runtime_error("Unsupported type code"); } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -784,52 +753,51 @@ int ion_buffer_write(ion_buffer_t obj, void *ptr, int size) return 0; } -int ion_buffer_read(ion_buffer_t obj, void *ptr, int size) -{ +int ion_buffer_read(ion_buffer_t obj, void *ptr, int size) { try { // NOTE: Halide::Buffer class layout is safe to call Halide::Buffer::type() - auto type = reinterpret_cast*>(obj)->type(); + auto type = reinterpret_cast *>(obj)->type(); if (type.is_int()) { if (type.bits() == 8) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 16) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 32) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 64) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_uint()) { if (type.bits() == 1) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 8) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 16) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 32) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 64) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else { throw std::runtime_error("Unsupported bits number"); } } else if (type.is_float()) { if (type.bits() == 32) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else if (type.bits() == 64) { - std::memcpy(ptr, reinterpret_cast*>(obj)->data(), size); + std::memcpy(ptr, reinterpret_cast *>(obj)->data(), size); } else { throw std::runtime_error("Unsupported bits number"); } } else { throw std::runtime_error("Unsupported type code"); } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -840,15 +808,13 @@ int ion_buffer_read(ion_buffer_t obj, void *ptr, int size) return 0; } - -int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name) -{ +int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char *name) { try { - *ptr = reinterpret_cast(new Graph(*reinterpret_cast(obj), name)); - } catch (const Halide::Error& e) { + *ptr = reinterpret_cast(new Graph(*reinterpret_cast(obj), name)); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -859,14 +825,13 @@ int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name) return 0; } -int ion_graph_add_node(ion_graph_t obj, const char *name, ion_node_t *node_ptr) -{ +int ion_graph_add_node(ion_graph_t obj, const char *name, ion_node_t *node_ptr) { try { - *node_ptr = reinterpret_cast(new Node(reinterpret_cast(obj)->add(name))); - } catch (const Halide::Error& e) { + *node_ptr = reinterpret_cast(new Node(reinterpret_cast(obj)->add(name))); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -877,17 +842,17 @@ int ion_graph_add_node(ion_graph_t obj, const char *name, ion_node_t *node_ptr) return 0; } -int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* graphs_ptr, int graphs_num) { +int ion_graph_create_with_multiple(ion_graph_t *ptr, ion_graph_t *graphs_ptr, int graphs_num) { try { - auto sum_graph = *reinterpret_cast(graphs_ptr[0]); - for (int i=1; i(graphs_ptr[i]); + auto sum_graph = *reinterpret_cast(graphs_ptr[0]); + for (int i = 1; i < graphs_num; ++i) { + sum_graph += *reinterpret_cast(graphs_ptr[i]); } *ptr = reinterpret_cast(new Graph(sum_graph)); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -898,14 +863,13 @@ int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* graphs_ptr, i return 0; } -int ion_graph_run(ion_graph_t obj) -{ +int ion_graph_run(ion_graph_t obj) { try { - reinterpret_cast(obj)->run(); - } catch (const Halide::Error& e) { + reinterpret_cast(obj)->run(); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { @@ -916,13 +880,13 @@ int ion_graph_run(ion_graph_t obj) return 0; } -int ion_graph_destroy(ion_graph_t obj){ +int ion_graph_destroy(ion_graph_t obj) { try { - delete reinterpret_cast(obj); - } catch (const Halide::Error& e) { + delete reinterpret_cast(obj); + } catch (const Halide::Error &e) { log::error(e.what()); return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { log::error(e.what()); return 1; } catch (...) { diff --git a/src/dynamic_module.h b/src/dynamic_module.h index ed0bdd83..53af5fc0 100644 --- a/src/dynamic_module.h +++ b/src/dynamic_module.h @@ -22,10 +22,10 @@ #include "log.h" namespace { -bool has_prefix_and_ext(const std::string& n) { +bool has_prefix_and_ext(const std::string &n) { return n.find(ION_DYNAMIC_MODULE_PREFIX) != std::string::npos && n.find(ION_DYNAMIC_MODULE_EXT) != std::string::npos; } -} +} // namespace namespace ion { @@ -37,7 +37,7 @@ class DynamicModule { using Handle = void *; #endif - DynamicModule(const std::string& module_name_or_path, bool essential = true, bool lazy_load = false) { + DynamicModule(const std::string &module_name_or_path, bool essential = true, bool lazy_load = false) { if (module_name_or_path == "") { handle_ = nullptr; return; diff --git a/src/graph.cc b/src/graph.cc index f6a43da1..d19e938f 100644 --- a/src/graph.cc +++ b/src/graph.cc @@ -7,7 +7,7 @@ namespace ion { struct Graph::Impl { - Builder & builder; + Builder &builder; std::string name; GraphID id; std::vector nodes; @@ -15,54 +15,46 @@ struct Graph::Impl { Halide::Pipeline pipeline; Halide::Callable callable; std::unique_ptr jit_ctx; - Halide::JITUserContext* jit_ctx_ptr; - std::vector args; + Halide::JITUserContext *jit_ctx_ptr; + std::vector args; - Impl(Builder & b, const std::string& n) - : id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get()) - { + Impl(Builder &b, const std::string &n) + : id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get()) { } }; -Graph::Graph() -{ +Graph::Graph() { } -Graph::Graph(Builder & builder, const std::string& name) - : impl_(new Impl(builder, name)) -{ +Graph::Graph(Builder &builder, const std::string &name) + : impl_(new Impl(builder, name)) { } -Graph& Graph::operator+=(const Graph& rhs) -{ +Graph &Graph::operator+=(const Graph &rhs) { impl_->nodes.insert(impl_->nodes.end(), rhs.impl_->nodes.begin(), rhs.impl_->nodes.end()); return *this; } -Graph operator+(const Graph& lhs, const Graph& rhs) -{ +Graph operator+(const Graph &lhs, const Graph &rhs) { Graph g(lhs.impl_->builder); g += lhs; g += rhs; return g; } -Node Graph::add(const std::string& name) -{ - auto n = impl_->builder.add(name,impl_->id); +Node Graph::add(const std::string &name) { + auto n = impl_->builder.add(name, impl_->id); impl_->nodes.push_back(n); return n; } -Graph& Graph::set_jit_context(Halide::JITUserContext *user_context_ptr) { +Graph &Graph::set_jit_context(Halide::JITUserContext *user_context_ptr) { impl_->jit_ctx_ptr = user_context_ptr; return *this; } - -void Graph::run() -{ - if (!impl_->pipeline.defined()) { +void Graph::run() { + if (!impl_->pipeline.defined()) { impl_->pipeline = lower(impl_->builder, impl_->nodes, false); if (!impl_->pipeline.defined()) { log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port."); @@ -81,20 +73,19 @@ void Graph::run() impl_->args.clear(); impl_->args.push_back(&impl_->jit_ctx_ptr); - const auto& args(generate_arguments_instance(inferred_args, impl_->nodes)); + const auto &args(generate_arguments_instance(inferred_args, impl_->nodes)); impl_->args.insert(impl_->args.end(), args.begin(), args.end()); } impl_->callable.call_argv_fast(impl_->args.size(), impl_->args.data()); } -const std::vector& Graph::nodes() const { +const std::vector &Graph::nodes() const { return impl_->nodes; } -std::vector& Graph::nodes() { +std::vector &Graph::nodes() { return impl_->nodes; } - -} // namespace ion +} // namespace ion diff --git a/src/log.cc b/src/log.cc index 29d217d2..8ee7f5da 100644 --- a/src/log.cc +++ b/src/log.cc @@ -1,5 +1,5 @@ #ifndef FMT_CONSTEVAL -#define FMT_CONSTEVAL // To prevent format string is evaluated as constexpr +#define FMT_CONSTEVAL // To prevent format string is evaluated as constexpr #endif #include "spdlog/cfg/helpers.h" #include "spdlog/details/os.h" @@ -23,35 +23,34 @@ bool should_log(level::level_enum level) { } } -} // log -} // ion +} // namespace log +} // namespace ion namespace { struct Logger { - Logger() - { - auto log_level = spdlog::level::off; - auto env_val = spdlog::details::os::getenv("ION_LOG_LEVEL"); - if (env_val.empty()) { - return; - } + Logger() { + auto log_level = spdlog::level::off; + auto env_val = spdlog::details::os::getenv("ION_LOG_LEVEL"); + if (env_val.empty()) { + return; + } - log_level = spdlog::level::from_str(env_val); + log_level = spdlog::level::from_str(env_val); - auto console_sink = std::make_shared(); - console_sink->set_level(log_level); + auto console_sink = std::make_shared(); + console_sink->set_level(log_level); - auto file_sink = std::make_shared("logs/ion.log", false); - file_sink->set_level(log_level); + auto file_sink = std::make_shared("logs/ion.log", false); + file_sink->set_level(log_level); - auto logger = std::make_shared("ion", spdlog::sinks_init_list{console_sink, file_sink}); - logger->set_level(log_level); + auto logger = std::make_shared("ion", spdlog::sinks_init_list{console_sink, file_sink}); + logger->set_level(log_level); - logger->debug("ion-kit version is {}", ION_KIT_VERSION); + logger->debug("ion-kit version is {}", ION_KIT_VERSION); - spdlog::register_logger(logger); - } + spdlog::register_logger(logger); + } } logger; -} // anonymous +} // namespace diff --git a/src/log.h b/src/log.h index f08a22d9..750c9f2b 100644 --- a/src/log.h +++ b/src/log.h @@ -2,7 +2,7 @@ #define ION_LOG_H #ifndef FMT_CONSTEVAL -#define FMT_CONSTEVAL // To prevent format string is evaluated as constexpr +#define FMT_CONSTEVAL // To prevent format string is evaluated as constexpr #endif #include "spdlog/spdlog.h" @@ -25,14 +25,32 @@ enum level_enum : int { std::shared_ptr get(); bool should_log(level::level_enum level); -template inline void critical(Args... args) { if (get()) get()->critical(args...); } -template inline void error (Args... args) { if (get()) get()->error (args...); } -template inline void warn (Args... args) { if (get()) get()->warn (args...); } -template inline void info (Args... args) { if (get()) get()->info (args...); } -template inline void debug (Args... args) { if (get()) get()->debug (args...); } -template inline void trace (Args... args) { if (get()) get()->trace (args...); } +template +inline void critical(Args... args) { + if (get()) get()->critical(args...); +} +template +inline void error(Args... args) { + if (get()) get()->error(args...); +} +template +inline void warn(Args... args) { + if (get()) get()->warn(args...); +} +template +inline void info(Args... args) { + if (get()) get()->info(args...); +} +template +inline void debug(Args... args) { + if (get()) get()->debug(args...); +} +template +inline void trace(Args... args) { + if (get()) get()->trace(args...); +} -} // log -} // ion +} // namespace log +} // namespace ion #endif diff --git a/src/lower.cc b/src/lower.cc index dfb257d9..51394028 100644 --- a/src/lower.cc +++ b/src/lower.cc @@ -8,13 +8,13 @@ namespace ion { namespace { -bool is_free(const std::string& pn) { +bool is_free(const std::string &pn) { return pn.find("_ion_iport_") != std::string::npos; } -std::tuple find_ith_input(const std::vector& arginfos, int i) { +std::tuple find_ith_input(const std::vector &arginfos, int i) { int j = 0; - for (const auto& arginfo : arginfos) { + for (const auto &arginfo : arginfos) { if (arginfo.dir != Halide::Internal::ArgInfoDirection::Input) { continue; } @@ -29,20 +29,20 @@ std::tuple find_ith_input(co return std::make_tuple(Halide::Internal::AbstractGenerator::ArgInfo(), false); } -bool is_ready(const std::vector& sorted, const Node& n) { +bool is_ready(const std::vector &sorted, const Node &n) { bool ready = true; - for (const auto& [pn, port] : n.iports()) { + for (const auto &[pn, port] : n.iports()) { // This port has predecessor dependency. Always ready to add. if (!port.has_pred()) { continue; } - const auto& port_(port); // This is workaround for Clang-14 (MacOS) + const auto &port_(port); // This is workaround for Clang-14 (MacOS) // Check port dependent node is already added ready &= std::find_if(sorted.begin(), sorted.end(), - [&](const Node& n) { - return n.id() == port_.pred_id(); + [&](const Node &n) { + return n.id() == port_.pred_id(); }) != sorted.end(); } return ready; @@ -50,14 +50,18 @@ bool is_ready(const std::vector& sorted, const Node& n) { std::string to_string(Halide::Argument::Kind kind) { switch (kind) { - case Halide::Argument::Kind::InputScalar: return "InputScalar"; - case Halide::Argument::Kind::InputBuffer: return "InputBuffer"; - case Halide::Argument::Kind::OutputBuffer: return "OutputBuffer"; - default: return "Unknown"; + case Halide::Argument::Kind::InputScalar: + return "InputScalar"; + case Halide::Argument::Kind::InputBuffer: + return "InputBuffer"; + case Halide::Argument::Kind::OutputBuffer: + return "OutputBuffer"; + default: + return "Unknown"; } } -void topological_sort(std::vector& nodes) { +void topological_sort(std::vector &nodes) { std::vector sorted; if (nodes.empty()) { return; @@ -80,9 +84,9 @@ void topological_sort(std::vector& nodes) { nodes.swap(sorted); } -} // anonymous +} // namespace -void determine_and_validate(std::vector& nodes) { +void determine_and_validate(std::vector &nodes) { auto generator_names = Halide::Internal::GeneratorRegistry::enumerate(); @@ -94,10 +98,10 @@ void determine_and_validate(std::vector& nodes) { auto bb(Halide::Internal::GeneratorRegistry::create(n.name(), Halide::GeneratorContext(n.target()))); // Validate and set parameters - for (const auto& p : n.params()) { + for (const auto &p : n.params()) { try { bb->set_generatorparam_value(p.key(), p.val()); - } catch (const Halide::CompileError& e) { + } catch (const Halide::CompileError &e) { auto msg = fmt::format("BuildingBlock \"{}\" has no parameter \"{}\"", n.name(), p.key()); log::error(msg); throw std::runtime_error(msg); @@ -106,18 +110,18 @@ void determine_and_validate(std::vector& nodes) { try { bb->build_pipeline(); - } catch (const Halide::CompileError& e) { + } catch (const Halide::CompileError &e) { log::error(e.what()); throw std::runtime_error(e.what()); } - const auto& arginfos(bb->arginfos()); + const auto &arginfos(bb->arginfos()); // validate input port auto i = 0; - for (auto& [pn, port] : n.iports()) { + for (auto &[pn, port] : n.iports()) { if (is_free(pn)) { - const auto& [arginfo, found] = find_ith_input(arginfos, i); + const auto &[arginfo, found] = find_ith_input(arginfos, i); if (!found) { auto msg = fmt::format("BuildingBlock \"{}\" has no input #{}", n.name(), i); log::error(msg); @@ -128,9 +132,9 @@ void determine_and_validate(std::vector& nodes) { pn = arginfo.name; } - const auto& pn_(pn); // This is workaround for Clang-14 (MacOS) + const auto &pn_(pn); // This is workaround for Clang-14 (MacOS) if (!std::count_if(arginfos.begin(), arginfos.end(), - [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo){ return pn_ == arginfo.name && Halide::Internal::ArgInfoDirection::Input == arginfo.dir; })) { + [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo) { return pn_ == arginfo.name && Halide::Internal::ArgInfoDirection::Input == arginfo.dir; })) { auto msg = fmt::format("BuildingBlock \"{}\" has no input \"{}\"", n.name(), pn); log::error(msg); throw std::runtime_error(msg); @@ -140,10 +144,10 @@ void determine_and_validate(std::vector& nodes) { } // validate output - for (const auto& [pn, port] : n.oports()) { - const auto& pn_(pn); // This is workaround for Clang-14 (MacOS) + for (const auto &[pn, port] : n.oports()) { + const auto &pn_(pn); // This is workaround for Clang-14 (MacOS) if (!std::count_if(arginfos.begin(), arginfos.end(), - [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo){ return pn_ == arginfo.name && Halide::Internal::ArgInfoDirection::Output == arginfo.dir; })) { + [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo) { return pn_ == arginfo.name && Halide::Internal::ArgInfoDirection::Output == arginfo.dir; })) { auto msg = fmt::format("BuildingBlock \"{}\" has no output \"{}\"", n.name(), pn); log::error(msg); throw std::runtime_error(msg); @@ -152,26 +156,26 @@ void determine_and_validate(std::vector& nodes) { } } -std::vector generate_arguments_instance(const std::vector& inferred_args, const std::vector& nodes) { - std::vector instances(inferred_args.size(), nullptr); +std::vector generate_arguments_instance(const std::vector &inferred_args, const std::vector &nodes) { + std::vector instances(inferred_args.size(), nullptr); // Input - for (const auto& node : nodes) { - for (const auto& [pn, port] : node.iports()) { + for (const auto &node : nodes) { + for (const auto &[pn, port] : node.iports()) { if (port.has_pred()) { continue; } auto i = 0; for (auto arg : port.as_argument()) { - auto it = std::find_if(inferred_args.begin(), inferred_args.end(), [arg](const Halide::Argument& inferred_arg) { return inferred_arg.name == arg.name; }); + auto it = std::find_if(inferred_args.begin(), inferred_args.end(), [arg](const Halide::Argument &inferred_arg) { return inferred_arg.name == arg.name; }); if (it == inferred_args.end()) { log::warn("Argument \"{}\" is not found in the inferred arguements", arg.name); i++; continue; } - auto idx = it-inferred_args.begin(); + auto idx = it - inferred_args.begin(); log::debug("Inserted \"{}\" instance at #{}", arg.name, idx); instances[idx] = port.as_instance()[i++]; } @@ -179,9 +183,9 @@ std::vector generate_arguments_instance(const std::vector generate_arguments_instance(const std::vector generate_arguments_instance(const std::vector& nodes, bool implicit_output) { +Halide::Pipeline lower(Builder builder, std::vector &nodes, bool implicit_output) { log::info("Start building pipeline"); @@ -233,33 +237,33 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ params["bb_id"] = to_string(n.id()); // User defined parameter - for (const auto& p : n.params()) { - params[p.key()] = p.val(); + for (const auto &p : n.params()) { + params[p.key()] = p.val(); } bb->set_generatorparam_values(params); bbs[n.id()] = std::move(bb); } // Assigning ports and build pipeline - for (size_t i=0; iarginfos(); - for (const auto& [pn, port] : n.iports()) { + for (const auto &[pn, port] : n.iports()) { // Find arginfo - auto it = std::find_if(arginfos.begin(), arginfos.end(), [&pn=pn](const ArgInfo& arginfo) { return arginfo.name == pn; }); + auto it = std::find_if(arginfos.begin(), arginfos.end(), [&pn = pn](const ArgInfo &arginfo) { return arginfo.name == pn; }); if (it == arginfos.end()) { auto msg = fmt::format("Argument {} is not defined in node {}", pn, n.name()); log::error(msg); throw std::runtime_error(msg); } - const auto& arginfo = *it; + const auto &arginfo = *it; auto index = port.index(); if (port.has_pred()) { - const auto& pred_bb(bbs[port.pred_id()]); + const auto &pred_bb(bbs[port.pred_id()]); auto fs = pred_bb->output_func(port.pred_name()); if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) { bb->bind_input(arginfo.name, fs); @@ -269,7 +273,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ bb->bind_input(arginfo.name, fs); } else { // access to Port[index] - if (index>=static_cast(fs.size())){ + if (index >= static_cast(fs.size())) { throw std::runtime_error("Port index out of range: " + port.pred_name()); } bb->bind_input(arginfo.name, {fs[index]}); @@ -300,9 +304,9 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ if (implicit_output) { // Collects all output which is never referenced. // This mode is used for AOT compilation - std::unordered_map, NodeID::StringIDHash> referenced; - for (const auto& n : nodes) { - for (const auto& [pn, port] : n.iports()) { + std::unordered_map, NodeID::StringIDHash> referenced; + for (const auto &n : nodes) { + for (const auto &[pn, port] : n.iports()) { if (port.has_pred()) { for (const auto &f : bbs[port.pred_id()]->output_func(port.pred_name())) { referenced[port.pred_id()].emplace_back(f.name()); @@ -311,7 +315,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ } } - for (const auto& node : nodes) { + for (const auto &node : nodes) { auto node_id = node.id(); for (auto arginfo : bbs[node_id]->arginfos()) { if (arginfo.dir != Halide::Internal::ArgInfoDirection::Output) { @@ -320,7 +324,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ // This is not output // It is not dereferenced, then treat as outputs - const auto& dv = referenced[node_id]; + const auto &dv = referenced[node_id]; for (auto f : bbs[node_id]->output_func(arginfo.name)) { auto it = std::find(dv.begin(), dv.end(), f.name()); @@ -334,26 +338,25 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ } else { // Collects all output which is bound with buffer. // This mode is used for JIT - for (const auto& node : nodes) { - for (const auto& [pn, port] : node.oports()) { - const auto& port_instances(port.as_instance()); + for (const auto &node : nodes) { + for (const auto &[pn, port] : node.oports()) { + const auto &port_instances(port.as_instance()); if (port_instances.empty()) { continue; } - const auto& pred_bb(bbs[port.pred_id()]); + const auto &pred_bb(bbs[port.pred_id()]); // Validate port exists - const auto& port_(port); // This is workaround for Clang-14 (MacOS) - const auto& pred_arginfos(pred_bb->arginfos()); + const auto &port_(port); // This is workaround for Clang-14 (MacOS) + const auto &pred_arginfos(pred_bb->arginfos()); if (!std::count_if(pred_arginfos.begin(), pred_arginfos.end(), - [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo){ return port_.pred_name() == arginfo.name && Halide::Internal::ArgInfoDirection::Output == arginfo.dir; })) { + [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo) { return port_.pred_name() == arginfo.name && Halide::Internal::ArgInfoDirection::Output == arginfo.dir; })) { auto msg = fmt::format("BuildingBlock \"{}\" has no output \"{}\"", pred_bb->name(), port.pred_name()); log::error(msg); throw std::runtime_error(msg); } - auto fs(bbs[port.pred_id()]->output_func(port.pred_name())); output_funcs.insert(output_funcs.end(), fs.begin(), fs.end()); } @@ -367,4 +370,4 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ return Halide::Pipeline(output_funcs); } -} // namespace ion +} // namespace ion diff --git a/src/lower.h b/src/lower.h index 92a2c1d7..cabdb43b 100644 --- a/src/lower.h +++ b/src/lower.h @@ -8,17 +8,17 @@ namespace Halide { class Argument; class Pipeline; -} +} // namespace Halide namespace ion { class Node; -void determine_and_validate(std::vector& nodes); +void determine_and_validate(std::vector &nodes); -std::vector generate_arguments_instance(const std::vector& inferred_args, const std::vector& nodes); +std::vector generate_arguments_instance(const std::vector &inferred_args, const std::vector &nodes); -Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_output); +Halide::Pipeline lower(Builder builder, std::vector &nodes, bool implicit_output); -} // namespace ion +} // namespace ion -#endif // ION_LOWER_H +#endif // ION_LOWER_H diff --git a/src/metadata.cc b/src/metadata.cc index cbee34c7..84cba9cf 100644 --- a/src/metadata.cc +++ b/src/metadata.cc @@ -9,31 +9,31 @@ namespace { -std::string unquote(const std::string& s) { +std::string unquote(const std::string &s) { if (s.size() < 2) { return s; } - size_t epos = s.size()-1; + size_t epos = s.size() - 1; if (s[0] == '"' && s[epos] == '"') { - return s.substr(1, s.size()-2); + return s.substr(1, s.size() - 2); } else { return s; } } -} // anonymous +} // namespace namespace ion { using json = nlohmann::json; -PortMD::PortMD(const std::string& n, const std::vector& ts, int d) - : name(n), types(ts), dimension(d) -{} +PortMD::PortMD(const std::string &n, const std::vector &ts, int d) + : name(n), types(ts), dimension(d) { +} -void to_json(json& j, const PortMD& v) { +void to_json(json &j, const PortMD &v) { j["name"] = v.name; std::vector types; for (auto t : v.types) { @@ -43,7 +43,7 @@ void to_json(json& j, const PortMD& v) { j["dimension"] = v.dimension; } -void from_json(const json& j, PortMD& v) { +void from_json(const json &j, PortMD &v) { v.name = j["name"].get(); auto types = j["types"].get>(); for (auto t : types) { @@ -52,25 +52,25 @@ void from_json(const json& j, PortMD& v) { v.dimension = j["dimension"]; } -ParamMD::ParamMD(const std::string& n, const std::string& dv, const std::string& ct, const std::string& td) - : name(n), default_value(dv), c_type(ct), type_decls(td) -{} +ParamMD::ParamMD(const std::string &n, const std::string &dv, const std::string &ct, const std::string &td) + : name(n), default_value(dv), c_type(ct), type_decls(td) { +} -void to_json(json& j, const ParamMD& v) { +void to_json(json &j, const ParamMD &v) { j["name"] = v.name; j["c_type"] = v.c_type; j["type_decls"] = v.type_decls; if (v.c_type.find("uint8_t") == 0) { - j["default_value"] = std::to_string(*reinterpret_cast(v.default_value.c_str())); + j["default_value"] = std::to_string(*reinterpret_cast(v.default_value.c_str())); } else if (v.c_type.find("uint8_t") == 0) { - j["default_value"] = std::to_string(*reinterpret_cast(v.default_value.c_str())); + j["default_value"] = std::to_string(*reinterpret_cast(v.default_value.c_str())); } else { j["default_value"] = v.default_value; } } -void from_json(const json& j, ParamMD& v) { +void from_json(const json &j, ParamMD &v) { v.name = j["name"].get(); v.c_type = j["c_type"]; v.type_decls = j["type_decls"]; @@ -88,9 +88,8 @@ void from_json(const json& j, ParamMD& v) { } } -Metadata::Metadata(const std::string& n) - : name(n) -{ +Metadata::Metadata(const std::string &n) + : name(n) { auto bb = Halide::Internal::GeneratorRegistry::create(n, Halide::GeneratorContext(Halide::get_host_target())); for (auto arginfo : bb->arginfos()) { @@ -112,18 +111,18 @@ Metadata::Metadata(const std::string& n) // } } -void to_json(json& j, const Metadata& v) { +void to_json(json &j, const Metadata &v) { j["name"] = v.name; j["inputs"] = v.inputs; j["outputs"] = v.outputs; j["params"] = v.params; } -void from_json(const json& j, Metadata& v) { +void from_json(const json &j, Metadata &v) { v.name = j["name"].get(); v.inputs = j["inputs"].get>(); v.outputs = j["outputs"].get>(); v.params = j["params"].get>(); } -} // namespace ion +} // namespace ion diff --git a/src/metadata.h b/src/metadata.h index eecfab11..d921dd0b 100644 --- a/src/metadata.h +++ b/src/metadata.h @@ -13,42 +13,44 @@ namespace ion { using json = nlohmann::json; struct PortMD { - friend void to_json(json&, const PortMD&); - friend void from_json(const json&, PortMD&); + friend void to_json(json &, const PortMD &); + friend void from_json(const json &, PortMD &); std::string name; std::vector types; int dimension; - PortMD() {} - PortMD(const std::string& n, const std::vector& ts, int d); + PortMD() { + } + PortMD(const std::string &n, const std::vector &ts, int d); }; struct ParamMD { - friend void to_json(json&, const ParamMD&); - friend void from_json(const json&, ParamMD&); + friend void to_json(json &, const ParamMD &); + friend void from_json(const json &, ParamMD &); std::string name; std::string default_value; std::string c_type; std::string type_decls; - ParamMD() {} - ParamMD(const std::string& n, const std::string& dv, const std::string& ct, const std::string& td); + ParamMD() { + } + ParamMD(const std::string &n, const std::string &dv, const std::string &ct, const std::string &td); }; struct Metadata { - friend void to_json(json&, const Metadata&); - friend void from_json(const json&, Metadata&); + friend void to_json(json &, const Metadata &); + friend void from_json(const json &, Metadata &); std::string name; std::vector inputs; std::vector outputs; std::vector params; - Metadata(const std::string& n); + Metadata(const std::string &n); }; -} //namespace ion +} // namespace ion #endif diff --git a/src/node.cc b/src/node.cc index f46bec25..05180247 100644 --- a/src/node.cc +++ b/src/node.cc @@ -4,10 +4,8 @@ namespace ion { - -Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_) - : id(id_), name(name_), target(target_), params(), ports() -{ +Node::Impl::Impl(const NodeID &id_, const std::string &name_, const Halide::Target &target_) + : id(id_), name(name_), target(target_), params(), ports() { auto bb(Halide::Internal::GeneratorRegistry::create(name_, Halide::GeneratorContext(target_))); if (!bb) { log::error("BuildingBlock {} is not found", name_); @@ -17,9 +15,8 @@ Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Targ arginfos = bb->arginfos(); } -Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID& graph_id_) - : id(id_), name(name_), target(target_), params(), ports(), graph_id(graph_id_) -{ +Node::Impl::Impl(const NodeID &id_, const std::string &name_, const Halide::Target &target_, const GraphID &graph_id_) + : id(id_), name(name_), target(target_), params(), ports(), graph_id(graph_id_) { auto bb(Halide::Internal::GeneratorRegistry::create(name_, Halide::GeneratorContext(target_))); if (!bb) { log::error("BuildingBlock {} is not found", name_); @@ -29,14 +26,14 @@ Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Targ arginfos = bb->arginfos(); } -void Node::set_iports(const std::vector& ports) { +void Node::set_iports(const std::vector &ports) { impl_->ports.erase(std::remove_if(impl_->ports.begin(), impl_->ports.end(), - [&](const Port &p) { return p.has_succ_by_nid(this->id());}), + [&](const Port &p) { return p.has_succ_by_nid(this->id()); }), impl_->ports.end()); size_t i = 0; - for (auto& port : ports) { + for (auto &port : ports) { // TODO: Validation is better to be done lazily after BuildingBlock::configure // // if (info.dir == Halide::Internal::ArgInfoDirection::Output) { @@ -50,7 +47,7 @@ void Node::set_iports(const std::vector& ports) { // NOTE: Is succ_chans name OK to be just leave as it is? port.impl_->succ_chans.insert({id(), "_ion_iport_" + std::to_string(i)}); - port.impl_ ->graph_id = impl_->graph_id; + port.impl_->graph_id = impl_->graph_id; impl_->ports.push_back(port); i++; @@ -58,20 +55,20 @@ void Node::set_iports(const std::vector& ports) { } void Node::set_iport(Port port) { - port.impl_ ->graph_id = impl_->graph_id; + port.impl_->graph_id = impl_->graph_id; port.impl_->succ_chans.insert({id(), port.pred_name()}); impl_->ports.push_back(port); } -void Node::set_iport(const std::string& name, Port port) { - port.impl_ ->graph_id = impl_->graph_id; +void Node::set_iport(const std::string &name, Port port) { + port.impl_->graph_id = impl_->graph_id; port.impl_->succ_chans.insert({id(), name}); impl_->ports.push_back(port); } -Port Node::operator[](const std::string& name) { +Port Node::operator[](const std::string &name) { auto it = std::find_if(impl_->ports.begin(), impl_->ports.end(), - [&](const Port& p){ return p.pred_id() == impl_->id && p.pred_name() == name; }); + [&](const Port &p) { return p.pred_id() == impl_->id && p.pred_name() == name; }); if (it == impl_->ports.end()) { // This is output port which is never referenced. // Bind myself as a predecessor and register @@ -84,10 +81,10 @@ Port Node::operator[](const std::string& name) { } } -Port Node::iport(const std::string& pn) { - for (const auto& p: impl_->ports) { +Port Node::iport(const std::string &pn) { + for (const auto &p : impl_->ports) { auto it = std::find_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(), - [&](const Port::Channel& c) { return std::get<0>(c) == impl_->id && std::get<1>(c) == pn; }); + [&](const Port::Channel &c) { return std::get<0>(c) == impl_->id && std::get<1>(c) == pn; }); if (it != p.impl_->succ_chans.end()) { return p; } @@ -100,9 +97,9 @@ Port Node::iport(const std::string& pn) { std::vector> Node::iports() const { std::vector> iports; - for (const auto& p: impl_->ports) { + for (const auto &p : impl_->ports) { auto it = std::find_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(), - [&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; }); + [&](const Port::Channel &c) { return std::get<0>(c) == impl_->id; }); if (it != p.impl_->succ_chans.end()) { iports.push_back(std::make_tuple(std::get<1>(*it), p)); } @@ -110,39 +107,38 @@ std::vector> Node::iports() const { return iports; } - std::vector> Node::unbound_iports() const { - std::vector> unbound_iports; - int iports_size = 0; + std::vector> unbound_iports; + int iports_size = 0; - for (const auto& p: impl_->ports) { + for (const auto &p : impl_->ports) { auto it = std::find_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(), - [&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; }); + [&](const Port::Channel &c) { return std::get<0>(c) == impl_->id; }); if (it != p.impl_->succ_chans.end()) { - iports_size+=1; + iports_size += 1; } } - int iports_idx = 0; - for (auto & arginfo: impl_->arginfos){ - if (arginfo.dir == Halide::Internal::ArgInfoDirection::Input) { - if(iports_idx>=iports_size){ - Port port("_ion_iport_" + std::to_string(iports_idx), arginfo.types.front()); - port.impl_->dimensions = arginfo.dimensions; - unbound_iports.push_back(std::make_tuple(arginfo.name, port)); - } - iports_idx ++; - } - } - return unbound_iports; + int iports_idx = 0; + for (auto &arginfo : impl_->arginfos) { + if (arginfo.dir == Halide::Internal::ArgInfoDirection::Input) { + if (iports_idx >= iports_size) { + Port port("_ion_iport_" + std::to_string(iports_idx), arginfo.types.front()); + port.impl_->dimensions = arginfo.dimensions; + unbound_iports.push_back(std::make_tuple(arginfo.name, port)); + } + iports_idx++; + } + } + return unbound_iports; } void Node::set_oport(Port port) { - port.impl_ ->graph_id = impl_->graph_id; - impl_->ports.push_back(port); + port.impl_->graph_id = impl_->graph_id; + impl_->ports.push_back(port); } -Port Node::oport(const std::string& pn) { +Port Node::oport(const std::string &pn) { return this->operator[](pn); // TODO: It is better to just return exisitng output port? @@ -161,7 +157,7 @@ Port Node::oport(const std::string& pn) { std::vector> Node::oports() const { std::vector> oports; - for (const auto& p: impl_->ports) { + for (const auto &p : impl_->ports) { if (id() == p.pred_id()) { oports.push_back(std::make_tuple(p.pred_name(), p)); } @@ -170,47 +166,47 @@ std::vector> Node::oports() const { } std::vector> Node::unbound_oports() const { - std::vector> unbound_oports; - int oports_size = 0; - - for (const auto& p: impl_->ports) { - if (id() == p.pred_id()) { - oports_size +=1; - } - } - int oports_idx = 0; - for (auto & arginfo: impl_->arginfos){ - if (arginfo.dir == Halide::Internal::ArgInfoDirection::Output) { - if(oports_idx>=oports_size){ - Port port(id(), arginfo.name); - port.impl_ ->type = arginfo.types.front(); - port.impl_->dimensions = arginfo.dimensions; - unbound_oports.push_back(std::make_tuple(arginfo.name, port)); - } - oports_idx ++; - } - } - return unbound_oports; + std::vector> unbound_oports; + int oports_size = 0; + + for (const auto &p : impl_->ports) { + if (id() == p.pred_id()) { + oports_size += 1; + } + } + int oports_idx = 0; + for (auto &arginfo : impl_->arginfos) { + if (arginfo.dir == Halide::Internal::ArgInfoDirection::Output) { + if (oports_idx >= oports_size) { + Port port(id(), arginfo.name); + port.impl_->type = arginfo.types.front(); + port.impl_->dimensions = arginfo.dimensions; + unbound_oports.push_back(std::make_tuple(arginfo.name, port)); + } + oports_idx++; + } + } + return unbound_oports; } -void Node::detect_data_hazard ()const { - std::vector> oports = Node::oports() ; - std::vector> iports = Node::iports() ; +void Node::detect_data_hazard() const { + std::vector> oports = Node::oports(); + std::vector> iports = Node::iports(); std::set> address_set; - for (auto& [pn, port] :oports) { - for(auto& [i, t] : port.impl_->bound_address){ + for (auto &[pn, port] : oports) { + for (auto &[i, t] : port.impl_->bound_address) { address_set.insert(t); } } - for (auto& [pn, port] :iports) { - for(auto& [i, t] : port.impl_->bound_address){ + for (auto &[pn, port] : iports) { + for (auto &[i, t] : port.impl_->bound_address) { if (address_set.find(t) != address_set.end()) { - std::get<1>(t) = true; + std::get<1>(t) = true; } } } }; -} // namespace ion +} // namespace ion diff --git a/src/port.cc b/src/port.cc index 0df0a1f9..17846e69 100644 --- a/src/port.cc +++ b/src/port.cc @@ -6,17 +6,15 @@ namespace ion { Port::Impl::Impl() - : id(PortID(sole::uuid4().str())), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1) -{ + : id(PortID(sole::uuid4().str())), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1) { } -Port::Impl::Impl(const NodeID & nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID & gid) - : id(PortID(sole::uuid4().str())), pred_chan{nid, pn}, succ_chans{}, type(t), dimensions(d), graph_id(gid) -{ - params[0] = Halide::Parameter(type, dimensions != 0, dimensions, argument_name(nid, id, pn, 0, gid)); +Port::Impl::Impl(const NodeID &nid, const std::string &pn, const Halide::Type &t, int32_t d, const GraphID &gid) + : id(PortID(sole::uuid4().str())), pred_chan{nid, pn}, succ_chans{}, type(t), dimensions(d), graph_id(gid) { + params[0] = Halide::Parameter(type, dimensions != 0, dimensions, argument_name(nid, id, pn, 0, gid)); } -void Port::determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn) { +void Port::determine_succ(const NodeID &nid, const std::string &old_pn, const std::string &new_pn) { auto it = std::find(impl_->succ_chans.begin(), impl_->succ_chans.end(), Channel{nid, old_pn}); if (it == impl_->succ_chans.end()) { log::error("fixme"); @@ -28,7 +26,7 @@ void Port::determine_succ(const NodeID& nid, const std::string& old_pn, const st impl_->succ_chans.insert(Channel{nid, new_pn}); } -std::tuple, bool> Port::find_impl(const std::string& id) { +std::tuple, bool> Port::find_impl(const std::string &id) { static std::unordered_map> impls; static std::mutex mutex; std::scoped_lock lock(mutex); @@ -41,4 +39,4 @@ std::tuple, bool> Port::find_impl(const std::string& return std::make_tuple(impls[id], found); } -} // namespace ion +} // namespace ion diff --git a/src/serializer.h b/src/serializer.h index fd5bfa48..79deae2f 100644 --- a/src/serializer.h +++ b/src/serializer.h @@ -11,97 +11,97 @@ #include "log.h" namespace nlohmann { -template <> +template<> struct adl_serializer { - static void to_json(json& j, const halide_type_t& v) { + static void to_json(json &j, const halide_type_t &v) { j["code"] = v.code; j["bits"] = v.bits; j["lanes"] = v.lanes; } - static void from_json(const json& j, halide_type_t& v) { + static void from_json(const json &j, halide_type_t &v) { v.code = j["code"]; v.bits = j["bits"]; v.lanes = j["lanes"]; } }; -template <> +template<> struct adl_serializer { -static void to_json(json& j, const ion::Param& v) { - j["key"] = v.key(); - j["val"] = v.val(); -} + static void to_json(json &j, const ion::Param &v) { + j["key"] = v.key(); + j["val"] = v.val(); + } -static void from_json(const json& j, ion::Param& v) { - v.key() = j["key"].get(); - v.val() = j["val"].get(); -} + static void from_json(const json &j, ion::Param &v) { + v.key() = j["key"].get(); + v.val() = j["val"].get(); + } }; template<> struct adl_serializer { - static void to_json(json& j, const ion::Port& v) { - j["id"] = to_string(v.id()); - std::map stringMap; - j["pred_chan"] = std::make_tuple(to_string(std::get<0>(v.pred_chan())), std::get<1>(v.pred_chan())); - std::set> succ_chans; - for (auto& c:v.succ_chans()){ - succ_chans.insert(std::make_tuple(to_string(std::get<0>(c)), std::get<1>(c))); - } - j["succ_chans"] = succ_chans; - j["type"] = static_cast(v.type()); - j["dimensions"] = v.dimensions(); - j["size"] = v.size(); - j["index"] = v.index(); - } + static void to_json(json &j, const ion::Port &v) { + j["id"] = to_string(v.id()); + std::map stringMap; + j["pred_chan"] = std::make_tuple(to_string(std::get<0>(v.pred_chan())), std::get<1>(v.pred_chan())); + std::set> succ_chans; + for (auto &c : v.succ_chans()) { + succ_chans.insert(std::make_tuple(to_string(std::get<0>(c)), std::get<1>(c))); + } + j["succ_chans"] = succ_chans; + j["type"] = static_cast(v.type()); + j["dimensions"] = v.dimensions(); + j["size"] = v.size(); + j["index"] = v.index(); + } - static void from_json(const json& j, ion::Port& v) { - auto [impl, found] = ion::Port::find_impl(j["id"].get()); - if (!found) { - impl->pred_chan = j["pred_chan"].get>(); - std::set succ_chans; - for (auto & p : j["succ_chans"]){ - succ_chans.insert(p.get>()); - } - impl->succ_chans = succ_chans; - impl->type = j["type"].get(); - impl->dimensions = j["dimensions"]; - for (auto i=0; iparams[i] = Halide::Parameter(impl->type, impl->dimensions != 0, impl->dimensions, - ion::argument_name(std::get<0>(impl->pred_chan), impl->id, std::get<1>(impl->pred_chan), i, impl->graph_id.value())); - } - } - v = ion::Port(impl, j["index"]); - } + static void from_json(const json &j, ion::Port &v) { + auto [impl, found] = ion::Port::find_impl(j["id"].get()); + if (!found) { + impl->pred_chan = j["pred_chan"].get>(); + std::set succ_chans; + for (auto &p : j["succ_chans"]) { + succ_chans.insert(p.get>()); + } + impl->succ_chans = succ_chans; + impl->type = j["type"].get(); + impl->dimensions = j["dimensions"]; + for (auto i = 0; i < j["size"]; ++i) { + impl->params[i] = Halide::Parameter(impl->type, impl->dimensions != 0, impl->dimensions, + ion::argument_name(std::get<0>(impl->pred_chan), impl->id, std::get<1>(impl->pred_chan), i, impl->graph_id.value())); + } + } + v = ion::Port(impl, j["index"]); + } }; -template <> +template<> struct adl_serializer { - static void to_json(json& j, const ion::Node& v) { - j["id"] = to_string(v.id()); - j["name"] = v.name(); - j["target"] = v.target().to_string(); - j["params"] = v.params(); - j["ports"] = v.ports(); - } + static void to_json(json &j, const ion::Node &v) { + j["id"] = to_string(v.id()); + j["name"] = v.name(); + j["target"] = v.target().to_string(); + j["params"] = v.params(); + j["ports"] = v.ports(); + } - static void from_json(const json& j, ion::Node& v) { - auto impl = std::make_shared(); - impl->id = j["id"].get(); - impl->name = j["name"].get(); - impl->target = Halide::Target(j["target"].get()); - impl->params = j["params"].get>(); - impl->ports = j["ports"].get>(); - auto bb(Halide::Internal::GeneratorRegistry::create(impl->name, Halide::GeneratorContext(impl->target))); - if (!bb) { - ion::log::error("BuildingBlock {} is not found", impl->name); - throw std::runtime_error("Failed to create building block object"); - } - impl->arginfos = bb->arginfos(); - v = ion::Node(impl); - } + static void from_json(const json &j, ion::Node &v) { + auto impl = std::make_shared(); + impl->id = j["id"].get(); + impl->name = j["name"].get(); + impl->target = Halide::Target(j["target"].get()); + impl->params = j["params"].get>(); + impl->ports = j["ports"].get>(); + auto bb(Halide::Internal::GeneratorRegistry::create(impl->name, Halide::GeneratorContext(impl->target))); + if (!bb) { + ion::log::error("BuildingBlock {} is not found", impl->name); + throw std::runtime_error("Failed to create building block object"); + } + impl->arginfos = bb->arginfos(); + v = ion::Node(impl); + } }; -} +} // namespace nlohmann #endif diff --git a/src/target.cc b/src/target.cc index 4c02d4f5..b66b0e9f 100644 --- a/src/target.cc +++ b/src/target.cc @@ -13,5 +13,4 @@ Target get_target_from_environment() { return Halide::get_target_from_environment(); } -} // ion - +} // namespace ion diff --git a/src/util.cc b/src/util.cc index 3314154a..a8c59f8e 100644 --- a/src/util.cc +++ b/src/util.cc @@ -5,19 +5,18 @@ namespace ion { -std::string argument_name(const NodeID & node_id, const PortID & portId, const std::string& name, int32_t index, const GraphID & graph_id) { +std::string argument_name(const NodeID &node_id, const PortID &portId, const std::string &name, int32_t index, const GraphID &graph_id) { if (index == -1) { index = 0; } - std::string s = "_" + node_id.value() + "_" + portId.value() + "_" + name + std::to_string(index) + "_" + graph_id.value(); + std::string s = "_" + node_id.value() + "_" + portId.value() + "_" + name + std::to_string(index) + "_" + graph_id.value(); std::replace(s.begin(), s.end(), '-', '_'); return s; } -std::string array_name(const std::string& port_name, size_t i) { +std::string array_name(const std::string &port_name, size_t i) { return port_name + "_" + std::to_string(i); } -} // namespace ion - +} // namespace ion diff --git a/test/array_dup_names.cc b/test/array_dup_names.cc index 08d8372a..f862b94f 100644 --- a/test/array_dup_names.cc +++ b/test/array_dup_names.cc @@ -49,10 +49,10 @@ int main() { } } } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/array_inout.cc b/test/array_inout.cc index 2a2ee474..680cb760 100644 --- a/test/array_inout.cc +++ b/test/array_inout.cc @@ -47,10 +47,10 @@ int main() { } } } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/array_input.cc b/test/array_input.cc index 51a9e962..f348c45e 100644 --- a/test/array_input.cc +++ b/test/array_input.cc @@ -26,8 +26,7 @@ int main() { Halide::Buffer{w, h}, Halide::Buffer{w, h}, Halide::Buffer{w, h}, - Halide::Buffer{w, h} - }; + Halide::Buffer{w, h}}; Halide::Buffer out(w, h); @@ -40,7 +39,7 @@ int main() { } } - for (int i=0; i(); Port input{"input", t, 2}, width{"width", t}, height{"height", t}; @@ -25,10 +24,10 @@ int main() ln = b.add("test_inc_i32x2")(n["output0"]).set_params(v0); rn = b.add("test_inc_i32x2")(n["output1"]).set_params(v0); b.compile("complex_graph"); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/complex_graph_jit.cc b/test/complex_graph_jit.cc index ff745225..131eb7e5 100644 --- a/test/complex_graph_jit.cc +++ b/test/complex_graph_jit.cc @@ -2,14 +2,13 @@ using namespace ion; -int main() -{ +int main() { try { int32_t size = 16; ion::Buffer in(std::vector{size, size}); - for (int y=0; y in(std::vector{size, size}); - for (int32_t y=0; y out0(std::vector{size, size/split_n}); - Buffer out1(std::vector{size, size/split_n}); - for (int32_t y=0; y out0(std::vector{size, size / split_n}); + Buffer out1(std::vector{size, size / split_n}); + for (int32_t y = 0; y < size / split_n; ++y) { + for (int32_t x = 0; x < size; ++x) { out0(x, y) = 0; out1(x, y) = 0; } } int ret = complex_graph(in, size, size, out0, out1); - for (int y=0; y { GeneratorParam num{"num", 0}; void configure() { - for (int32_t i=0; i("extra_scalar_input_" + std::to_string(i))); } } @@ -18,7 +18,7 @@ struct Test : BuildingBlock { void generate() { Halide::Var i; Halide::Expr v = input(i); - for (int i=0; i output{size}; b.add("test")(input)["output"].bind(output); b.run(); - for (int i=0; i { ion::Input input{"input0", Int(32), 2}; ion::Output output{"output0", Int(32), 2}; @@ -109,13 +108,13 @@ int main() { CUdevice device; CU_SAFE_CALL(cuDeviceGet(&device, 0)); - + CU_SAFE_CALL(cuCtxCreate(reinterpret_cast(&state.cuda_context), 0, device)); std::cout << "CUcontext is created on application side : " << state.cuda_context << std::endl; - + CU_SAFE_CALL(cuCtxSetCurrent(reinterpret_cast(state.cuda_context))); - + // CUstream is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html) CU_SAFE_CALL(cuStreamCreate(reinterpret_cast(&state.cuda_stream), CU_STREAM_DEFAULT)); @@ -166,7 +165,7 @@ int main() { // CUDA cleanup CUDA_SAFE_CALL(cudaFree(src)); CUDA_SAFE_CALL(cudaFree(dst)); - + CU_SAFE_CALL(cuStreamDestroy(reinterpret_cast(state.cuda_stream))); CU_SAFE_CALL(cuCtxDestroy(reinterpret_cast(state.cuda_context))); diff --git a/test/direct-extern.cc b/test/direct-extern.cc index 9e64cad6..adbf4890 100644 --- a/test/direct-extern.cc +++ b/test/direct-extern.cc @@ -2,8 +2,7 @@ using namespace ion; -int main() -{ +int main() { try { int size = 32; @@ -13,15 +12,15 @@ int main() Halide::Buffer ibuf(std::vector{size, size}); Halide::Buffer obuf(std::vector{size, size}); - for (int y=0; y in(std::vector{width, height, 3}); ion::Buffer r = ion::Buffer::make_scalar(); - for (int y=0; y(); Port input{"input", t, 2}, width{"width", t}, height{"height", t}; @@ -18,10 +17,10 @@ int main() n = b.add("test_merge")(n["output0"], n["output1"], height); b.compile("complex_graph"); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/dup.cc b/test/dup.cc index 6518f842..86c4071e 100644 --- a/test/dup.cc +++ b/test/dup.cc @@ -5,8 +5,7 @@ using namespace ion; -int main() -{ +int main() { try { Param v41("v", 41); @@ -24,7 +23,7 @@ int main() outBuf1(0, 0) = 1; intm.bind(outBuf1); - for (int i=0; i<10; ++i) { + for (int i = 0; i < 10; ++i) { b.run(); if (outBuf0(0, 0) != outBuf1(0, 0)) { std::cout << "o0:" << outBuf0(0, 0) << std::endl; @@ -33,10 +32,10 @@ int main() } } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/error.cc b/test/error.cc index 75c3832e..201135e9 100644 --- a/test/error.cc +++ b/test/error.cc @@ -6,8 +6,7 @@ using namespace std; using namespace ion; -int main() -{ +int main() { try { Builder b; b.add("sonzai_shinai_bb"); diff --git a/test/export.cc b/test/export.cc index e23b2868..bb75d664 100644 --- a/test/export.cc +++ b/test/export.cc @@ -2,181 +2,179 @@ using namespace ion; -int main() -{ +int main() { try { // simple_graph { - { - ion::Type t = ion::type_of(); - Port min0{"min0", t}, extent0{"extent0", t}, min1{"min1", t}, extent1{"extent1", t}, v{"v", t}; - Param v41("v", 41); - Builder b; - b.with_bb_module("ion-bb-test"); - b.set_target(get_host_target()); - Node n; - n = b.add("test_producer").set_params(v41); - n = b.add("test_consumer")(n["output"], min0, extent0, min1, extent1, v); - b.save("simple_graph.json"); - } - { - int32_t min0 = 0, extent0 = 2, min1 = 0, extent1 = 2, v = 1; + {ion::Type t = ion::type_of(); + Port min0{"min0", t}, extent0{"extent0", t}, min1{"min1", t}, extent1{"extent1", t}, v{"v", t}; + Param v41("v", 41); + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(get_host_target()); + Node n; + n = b.add("test_producer").set_params(v41); + n = b.add("test_consumer")(n["output"], min0, extent0, min1, extent1, v); + b.save("simple_graph.json"); + } + { + int32_t min0 = 0, extent0 = 2, min1 = 0, extent1 = 2, v = 1; - ion::Type t = ion::type_of(); - Builder b; - b.with_bb_module("ion-bb-test"); - b.load("simple_graph.json"); + ion::Type t = ion::type_of(); + Builder b; + b.with_bb_module("ion-bb-test"); + b.load("simple_graph.json"); - auto out = ion::Buffer::make_scalar(); + auto out = ion::Buffer::make_scalar(); - auto nodes = b.nodes(); - nodes[1].iport("min0").bind(&min0); - nodes[1].iport("extent0").bind(&extent0); - nodes[1].iport("min1").bind(&min1); - nodes[1].iport("extent1").bind(&extent1); - nodes[1].iport("v").bind(&v); - nodes[1].oport("output").bind(out); + auto nodes = b.nodes(); + nodes[1].iport("min0").bind(&min0); + nodes[1].iport("extent0").bind(&extent0); + nodes[1].iport("min1").bind(&min1); + nodes[1].iport("extent1").bind(&extent1); + nodes[1].iport("v").bind(&v); + nodes[1].oport("output").bind(out); - b.run(); - } - } + b.run(); + } +} - // complex_graph - { - ion::Type t = ion::type_of(); - Port input{"input", t, 2}, width{"width", t}, height{"height", t}; - Param v1("v", 1); - Builder b; - b.with_bb_module("ion-bb-test"); - b.set_target(ion::get_host_target()); - Node n; - n = b.add("test_inc_i32x2")(input).set_params(v1); - n = b.add("test_branch")(n["output"], width, height); - auto ln = b.add("test_inc_i32x2")(n["output0"]); - auto rn = b.add("test_inc_i32x2")(n["output1"]).set_params(v1); - n = b.add("test_merge")(ln["output"], rn["output"], height); - b.save("complex_graph.json"); - } +// complex_graph +{ + ion::Type t = ion::type_of(); + Port input{"input", t, 2}, width{"width", t}, height{"height", t}; + Param v1("v", 1); + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(ion::get_host_target()); + Node n; + n = b.add("test_inc_i32x2")(input).set_params(v1); + n = b.add("test_branch")(n["output"], width, height); + auto ln = b.add("test_inc_i32x2")(n["output0"]); + auto rn = b.add("test_inc_i32x2")(n["output1"]).set_params(v1); + n = b.add("test_merge")(ln["output"], rn["output"], height); + b.save("complex_graph.json"); +} - { - ion::Type t = ion::type_of(); - Builder b; - b.with_bb_module("ion-bb-test"); - b.load("complex_graph.json"); - b.set_target(ion::get_host_target()); - - int32_t size = 16; - int32_t split_n = 2; - - ion::Buffer in(std::vector{size, size}); - for (int y=0; y(); + Builder b; + b.with_bb_module("ion-bb-test"); + b.load("complex_graph.json"); + b.set_target(ion::get_host_target()); + + int32_t size = 16; + int32_t split_n = 2; + + ion::Buffer in(std::vector{size, size}); + for (int y = 0; y < size; ++y) { + for (int x = 0; x < size; ++x) { + in(x, y) = 40; + } + } - ion::Buffer out(std::vector{size, size}); + ion::Buffer out(std::vector{size, size}); - auto nodes = b.nodes(); + auto nodes = b.nodes(); - nodes[0].iport("input").bind(in); - nodes[1].iport("input_width").bind(&size); - nodes[1].iport("input_height").bind(&size); - nodes[4].iport("output_height").bind(&size); - nodes[4].oport("output").bind(out); + nodes[0].iport("input").bind(in); + nodes[1].iport("input_width").bind(&size); + nodes[1].iport("input_height").bind(&size); + nodes[4].iport("output_height").bind(&size); + nodes[4].oport("output").bind(out); - b.compile("ex"); - b.run(); + b.compile("ex"); + b.run(); - int y=0; - for (; y(), 2}; - Builder b; - b.with_bb_module("ion-bb-test"); - b.set_target(ion::get_host_target()); - auto n = b.add("test_array_output")(input).set_params(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]).set_params(Param("array_input.size", len)); - b.save("array_inout.json"); +// Array inout +{ + constexpr size_t h = 10, w = 10, len = 5; + { + Port input{"input", ion::type_of(), 2}; + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(ion::get_host_target()); + auto n = b.add("test_array_output")(input).set_params(Param("len", len)); + n = b.add("test_array_input")(n["array_output"]).set_params(Param("array_input.size", len)); + b.save("array_inout.json"); + } + { + Port input{"input", ion::type_of(), 2}; + Builder b; + b.with_bb_module("ion-bb-test"); + b.load("array_inout.json"); + + ion::Buffer in(w, h); + for (int y = 0; y < h; ++y) { + for (int x = 0; x < w; ++x) { + in(x, y) = y * w + x; } - { - Port input{"input", ion::type_of(), 2}; - Builder b; - b.with_bb_module("ion-bb-test"); - b.load("array_inout.json"); - - ion::Buffer in(w, h); - for (int y = 0; y < h; ++y) { - for (int x = 0; x < w; ++x) { - in(x, y) = y * w + x; - } - } + } - ion::Buffer out(w, h); + ion::Buffer out(w, h); - for (auto& n : b.nodes()) { - if (n.name() == "test_array_output") { - n.iport("input").bind(in); - } else if (n.name() == "test_array_input") { - n.oport("output").bind(out); - } - } + for (auto &n : b.nodes()) { + if (n.name() == "test_array_output") { + n.iport("input").bind(in); + } else if (n.name() == "test_array_input") { + n.oport("output").bind(out); + } + } - b.run(); + b.run(); - if (out.dimensions() != 2) { - return 1; - } - if (out.extent(0) != h) { - return 1; - } - if (out.extent(1) != w) { - return 1; - } + if (out.dimensions() != 2) { + return 1; + } + if (out.extent(0) != h) { + return 1; + } + if (out.extent(1) != w) { + return 1; + } - for (int y = 0; y < h; ++y) { - for (int x = 0; x < w; ++x) { - if (len * in(x, y) != out(x, y)) { - return 1; - } - } + for (int y = 0; y < h; ++y) { + for (int x = 0; x < w; ++x) { + if (len * in(x, y) != out(x, y)) { + return 1; } - } } - } catch (const Halide::Error &e) { - std::cerr << e.what() << std::endl; - return 1; - } catch (const std::exception &e) { - std::cerr << e.what() << std::endl; - return 1; } +} +} +catch (const Halide::Error &e) { + std::cerr << e.what() << std::endl; + return 1; +} +catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; +} - std::cout << "Passed" << std::endl; +std::cout << "Passed" << std::endl; - return 0; +return 0; } diff --git a/test/gpu-extern.cc b/test/gpu-extern.cc index f4e4a8c1..83cfbf6a 100644 --- a/test/gpu-extern.cc +++ b/test/gpu-extern.cc @@ -5,8 +5,7 @@ using namespace ion; -int main() -{ +int main() { try { int size = 32; @@ -23,27 +22,25 @@ int main() n = b.add("test_extern_inc_i32x2")(ip).set_params(wp, hp, vp); n = b.add("test_extern_inc_i32x2")(n["output"]).set_params(wp, hp, vp); - - Halide::Buffer ibuf(std::vector{size, size}); - for (int y=0; y obuf(std::vector{size, size}); - for (int y=0; y(), 2}; @@ -25,10 +24,10 @@ int main() b.compile("gpu_extern"); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/gpu-extern_run.cc b/test/gpu-extern_run.cc index 64c72ddf..4fd69e0d 100644 --- a/test/gpu-extern_run.cc +++ b/test/gpu-extern_run.cc @@ -6,29 +6,28 @@ using namespace Halide::Runtime; -int main() -{ +int main() { try { int size = 32; Buffer ibuf(std::vector{size, size}); - for (int y=0; y obuf(std::vector{size, size}); - for (int y=0; y ibuf0(std::vector{1, 1}); - Port ip0{"input", Halide::type_of(), 2}; Port vp0{"v", Halide::type_of()}; @@ -135,7 +132,6 @@ int main() obuf1.fill(0); n1["output"].bind(obuf1); - ibuf1(0, 0) = 42; v1 = 1; obuf1(0, 0) = 0; @@ -172,5 +168,4 @@ int main() } std::cout << "All Passed" << std::endl; return 0; - } diff --git a/test/inverted_dep.cc b/test/inverted_dep.cc index fe76a7e2..a16687c8 100644 --- a/test/inverted_dep.cc +++ b/test/inverted_dep.cc @@ -7,8 +7,7 @@ using namespace ion; -int main() -{ +int main() { try { const char *file_name = "test.graph"; @@ -200,7 +199,7 @@ int main() b.set_target(Halide::get_host_target()); b.load(file_name); - for (auto& n : b.nodes()) { + for (auto &n : b.nodes()) { std::cout << n.name() << std::endl; if (n.name() == "test_consumer") { n.iport("min0").bind(&min0); diff --git a/test/ion-bb-test.cc b/test/ion-bb-test.cc index 40fe422d..c8beb693 100644 --- a/test/ion-bb-test.cc +++ b/test/ion-bb-test.cc @@ -5,31 +5,28 @@ #include "dynamic_module.h" #include "log.h" -extern "C" -void register_externs(std::map& externs) { +extern "C" void register_externs(std::map &externs) { externs.insert({"consume", Halide::JITExtern(consume)}); externs.insert({"branch", Halide::JITExtern(branch)}); externs.insert({"inc", Halide::JITExtern(inc)}); } -extern "C" -int consume_dispose(const char *id) { +extern "C" int consume_dispose(const char *id) { ion::log::info("consume_dispose is called with id={}", id); return 0; } -extern "C" -int consume(halide_buffer_t *in, halide_buffer_t *id_buf, int desired_min0, int desired_extent0, int desired_min1, int desired_extent1, int32_t v, halide_buffer_t *out) { +extern "C" int consume(halide_buffer_t *in, halide_buffer_t *id_buf, int desired_min0, int desired_extent0, int desired_min1, int desired_extent1, int32_t v, halide_buffer_t *out) { if (in->is_bounds_query()) { in->dim[0].min = desired_min0; in->dim[0].extent = desired_extent0; in->dim[1].min = desired_min1; in->dim[1].extent = desired_extent1; } else { - ion::log::info("consume is called with id={}", reinterpret_cast(id_buf->host)); + ion::log::info("consume is called with id={}", reinterpret_cast(id_buf->host)); Halide::Runtime::Buffer ibuf(*in); - for (int y=0; ydim[1].extent; ++y) { - for (int x=0; xdim[0].extent; ++x) { + for (int y = 0; y < in->dim[1].extent; ++y) { + for (int x = 0; x < in->dim[0].extent; ++x) { std::cout << ibuf(x, y) + v << " "; } std::cout << std::endl; @@ -39,20 +36,19 @@ int consume(halide_buffer_t *in, halide_buffer_t *id_buf, int desired_min0, int return 0; } -extern "C" -int branch(halide_buffer_t *in, int32_t input_width, int32_t input_height, halide_buffer_t *out0, halide_buffer_t *out1) { +extern "C" int branch(halide_buffer_t *in, int32_t input_width, int32_t input_height, halide_buffer_t *out0, halide_buffer_t *out1) { if (in->is_bounds_query() || out0->is_bounds_query() || out1->is_bounds_query()) { if (out0->is_bounds_query()) { out0->dim[0].min = 0; out0->dim[0].extent = input_width; out0->dim[1].min = 0; - out0->dim[1].extent = input_height/2; + out0->dim[1].extent = input_height / 2; } if (out1->is_bounds_query()) { out1->dim[0].min = 0; out1->dim[0].extent = input_width; out1->dim[1].min = 0; - out1->dim[1].extent = input_height/2; + out1->dim[1].extent = input_height / 2; } if (in->is_bounds_query()) { in->dim[0].min = 0; @@ -64,10 +60,10 @@ int branch(halide_buffer_t *in, int32_t input_width, int32_t input_height, halid Halide::Runtime::Buffer ibuf(*in); Halide::Runtime::Buffer obuf0(*out0); Halide::Runtime::Buffer obuf1(*out1); - for (int y=0; yis_bounds_query()) { @@ -109,8 +104,8 @@ int inc(halide_buffer_t *in, int32_t width, int32_t height, int32_t v, bool use_ static ion::DynamicModule dm("gpu-extern-lib"); call_inc_kernel_t call_inc_kernel = dm.get_symbol("call_inc_kernel"); - call_inc_kernel(reinterpret_cast(ibuf.raw_buffer()->device), obuf.extent(0), obuf.extent(1), v, - reinterpret_cast(obuf.raw_buffer()->device)); + call_inc_kernel(reinterpret_cast(ibuf.raw_buffer()->device), obuf.extent(0), obuf.extent(1), v, + reinterpret_cast(obuf.raw_buffer()->device)); if (copy_to_host) { obuf.set_host_dirty(false); @@ -118,8 +113,8 @@ int inc(halide_buffer_t *in, int32_t width, int32_t height, int32_t v, bool use_ obuf.copy_to_host(); } } else { - for (int y=obuf.min(1); y in(std::vector{nx, ny}); - for (int y=0; y o0(std::vector{nx}); ion::Buffer o1(std::vector{nx, ny}); ion::Buffer o2(std::vector{nx, ny, nc}); - for (int c=0; c in(std::vector{nx, ny}); - for (int y=0; y o0(std::vector{nx2, ny2}); - for (int y=0; y o1(std::vector{nx, ny}); - for (int y=0; y in0(std::vector{size, size}); in0.fill(0); Buffer out0(std::vector{size, size}); - n = b.add("test_inc_i32x2")(in0).set_params(Param{"v", 1});; + n = b.add("test_inc_i32x2")(in0).set_params(Param{"v", 1}); + ; n["output"].bind(out0); Buffer in1(std::vector{size, size}); @@ -28,8 +27,8 @@ int main() b.run(); - for (int y=0; y #include - using namespace ion; void display_image_float(Halide::Buffer buffer, std::string filename) { @@ -17,9 +16,9 @@ void display_image_float(Halide::Buffer buffer, std::string filename) { if (channels == 3) { cv::Mat img_float; cv::merge(std::vector{ - cv::Mat(height, width, CV_32F, buffer.data() + width * height * 2), - cv::Mat(height, width, CV_32F, buffer.data() + width * height * 1), - cv::Mat(height, width, CV_32F, buffer.data())}, + cv::Mat(height, width, CV_32F, buffer.data() + width * height * 2), + cv::Mat(height, width, CV_32F, buffer.data() + width * height * 1), + cv::Mat(height, width, CV_32F, buffer.data())}, img_float); img_float.convertTo(img_out, CV_8U, 255); } else { @@ -27,12 +26,11 @@ void display_image_float(Halide::Buffer buffer, std::string filename) { img_float.convertTo(img_out, CV_8U, 255); } #ifdef DISPLAY - cv::imshow( "Display window: " + filename, img_out); + cv::imshow("Display window: " + filename, img_out); cv::waitKey(3000); #endif } - int main(int argc, char *argv[]) { try { int width = 200; @@ -49,20 +47,12 @@ int main(int argc, char *argv[]) { b.with_bb_module("ion-bb"); Node n; - n = b.add("image_io_cameraN").set_params( - wparam, - hparam, - Param("num_devices", 2), - Param("urls", "http://optipng.sourceforge.net/pngtech/img/lena.png;http://upload.wikimedia.org/wikipedia/commons/0/05/Cat.png") - ); + n = b.add("image_io_cameraN").set_params(wparam, hparam, Param("num_devices", 2), Param("urls", "http://optipng.sourceforge.net/pngtech/img/lena.png;http://upload.wikimedia.org/wikipedia/commons/0/05/Cat.png")); n = b.add("base_normalize_3d_uint8")(n["output"][1]); // access only port[1] - n = b.add("image_processing_resize_nearest_3d")(n["output"]).set_params( - Param("width", width), - Param("height", height), - Param("scale", 2)); + n = b.add("image_processing_resize_nearest_3d")(n["output"]).set_params(Param("width", width), Param("height", height), Param("scale", 2)); Port output = n["output"]; - Halide::Buffer out_buf( width, height,3); + Halide::Buffer out_buf(width, height, 3); output.bind(out_buf); b.run(); diff --git a/test/port-assign.cc b/test/port-assign.cc index 176dfd24..c9c11a55 100644 --- a/test/port-assign.cc +++ b/test/port-assign.cc @@ -5,8 +5,7 @@ using namespace ion; -int main() -{ +int main() { try { Builder b; b.set_target(Halide::get_host_target()); @@ -40,10 +39,10 @@ int main() return 1; } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/port-binding.cc b/test/port-binding.cc index e5ad982b..9ac801ab 100644 --- a/test/port-binding.cc +++ b/test/port-binding.cc @@ -5,8 +5,7 @@ using namespace ion; -int main() -{ +int main() { try { Builder b; b.set_target(Halide::get_host_target()); @@ -58,10 +57,10 @@ int main() return 1; } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/simple_graph_compile.cc b/test/simple_graph_compile.cc index b92c0999..d1cd8835 100644 --- a/test/simple_graph_compile.cc +++ b/test/simple_graph_compile.cc @@ -4,8 +4,7 @@ using namespace ion; -int main() -{ +int main() { try { Halide::Type t = Halide::type_of(); Port min0{"min0", t}, extent0{"extent0", t}, min1{"min1", t}, extent1{"extent1", t}, v{"v", t}; @@ -16,10 +15,10 @@ int main() n = b.add("test_producer").set_params(v41); n = b.add("test_consumer")(n["output"], min0, extent0, min1, extent1, v); b.compile("simple_graph"); - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/simple_graph_jit.cc b/test/simple_graph_jit.cc index a1a21a5c..42d7c71b 100644 --- a/test/simple_graph_jit.cc +++ b/test/simple_graph_jit.cc @@ -2,8 +2,7 @@ using namespace ion; -int main() -{ +int main() { try { // New API int32_t min0 = 0, extent0 = 2, min1 = 0, extent1 = 2, v = 1; @@ -20,15 +19,15 @@ int main() b.save("simple_graph.graph"); - for (int i=0; i<5; ++i) { + for (int i = 0; i < 5; ++i) { std::cout << i << "'th loop" << std::endl; b.run(); } - } catch (Halide::Error& e) { + } catch (Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/simple_graph_run.cc b/test/simple_graph_run.cc index 226f4e30..77736b0a 100644 --- a/test/simple_graph_run.cc +++ b/test/simple_graph_run.cc @@ -5,8 +5,7 @@ using namespace Halide::Runtime; -int main() -{ +int main() { Buffer out = Buffer::make_scalar(); return simple_graph(2, 2, 0, 0, 1, out); } diff --git a/test/test-bb.h b/test/test-bb.h index ffbb341a..36b72998 100644 --- a/test/test-bb.h +++ b/test/test-bb.h @@ -19,6 +19,7 @@ class Producer : public BuildingBlock { void schedule() { output.compute_root(); } + private: Halide::Var x, y; }; @@ -49,6 +50,7 @@ class Consumer : public BuildingBlock { void schedule() { } + private: Halide::Var x, y; }; @@ -91,7 +93,7 @@ class Merge : public BuildingBlock { Output output{"output", Int(32), 2}; void generate() { - output(x, y) = select(y < Halide::cast(output_height)/2, input0(x, y), input1(x, clamp(y - output_height, 0, output_height))); + output(x, y) = select(y < Halide::cast(output_height) / 2, input0(x, y), input1(x, clamp(y - output_height, 0, output_height))); } void schedule() { @@ -130,7 +132,7 @@ class Inc : public BuildingBlock> { private: Halide::Var x, y; }; -using IncI32x2 = Inc; +using IncI32x2 = Inc; template class IncX : public BuildingBlock> { @@ -149,7 +151,7 @@ class IncX : public BuildingBlock> { private: Halide::Var x, y; }; -using IncXI32x2 = IncX; +using IncXI32x2 = IncX; class Dup : public BuildingBlock { public: @@ -172,7 +174,7 @@ class Scale2x : public BuildingBlock { Output output{"output", Int(32), 2}; void generate() { - output(x, y) = input(x/2, y/2); + output(x, y) = input(x / 2, y / 2); } private: @@ -204,7 +206,7 @@ class ArrayInput : public BuildingBlock { void generate() { Halide::Expr v = 0; for (int i = 0; i < array_input.size(); ++i) { - v += array_input[i](x, y); + v += array_input[i](x, y); } output(x, y) = v; } @@ -246,7 +248,6 @@ class ArrayCopy : public BuildingBlock { Halide::Var x, y; }; - class ExternIncI32x2 : public BuildingBlock { public: BuildingBlockParam v{"v", 0}; @@ -297,7 +298,7 @@ class SubI32x2 : public BuildingBlock { class IncByOffset : public BuildingBlock { public: Input input{"input", Int(32), 2}; - Input input_offset{"input_offset", Int(32), 0}; // to imitate scalar input + Input input_offset{"input_offset", Int(32), 0}; // to imitate scalar input BuildingBlockParam v{"v", 1}; Output output{"output", Int(32), 2}; Output output_offset{"output_offset"}; @@ -311,10 +312,9 @@ class IncByOffset : public BuildingBlock { Halide::Var x, y; }; - -} // test -} // bb -} // ion +} // namespace test +} // namespace bb +} // namespace ion ION_REGISTER_BUILDING_BLOCK(ion::bb::test::Producer, test_producer); ION_REGISTER_BUILDING_BLOCK(ion::bb::test::Consumer, test_consumer); diff --git a/test/test-rt.h b/test/test-rt.h index 9a2062bd..4b5bb906 100644 --- a/test/test-rt.h +++ b/test/test-rt.h @@ -10,17 +10,13 @@ #define DLLEXPORT #endif -extern "C" DLLEXPORT -int consume_dispose(const char *id); +extern "C" DLLEXPORT int consume_dispose(const char *id); -extern "C" DLLEXPORT -int consume(halide_buffer_t *in, halide_buffer_t *id_buf, int desired_min0, int desired_extent0, int desired_min1, int desired_extent1, int32_t v, halide_buffer_t *out); +extern "C" DLLEXPORT int consume(halide_buffer_t *in, halide_buffer_t *id_buf, int desired_min0, int desired_extent0, int desired_min1, int desired_extent1, int32_t v, halide_buffer_t *out); -extern "C" DLLEXPORT -int branch(halide_buffer_t *in, int32_t input_width, int32_t input_height, halide_buffer_t *out0, halide_buffer_t *out1); +extern "C" DLLEXPORT int branch(halide_buffer_t *in, int32_t input_width, int32_t input_height, halide_buffer_t *out0, halide_buffer_t *out1); -extern "C" DLLEXPORT -int inc(halide_buffer_t *in, int32_t width, int32_t height, int32_t v, bool use_gpu, halide_buffer_t *out); +extern "C" DLLEXPORT int inc(halide_buffer_t *in, int32_t width, int32_t height, int32_t v, bool use_gpu, halide_buffer_t *out); #undef DLLEXPORT diff --git a/test/unbound_binding.cc b/test/unbound_binding.cc index e27a46ad..430494f4 100644 --- a/test/unbound_binding.cc +++ b/test/unbound_binding.cc @@ -34,40 +34,39 @@ int main() { Halide::Buffer param_buf = Halide::Buffer::make_scalar(); param_buf.fill(1); -// std::vector< int > sizes; -// Halide::Buffer param_buf1(param_buf.data(),sizes); + // std::vector< int > sizes; + // Halide::Buffer param_buf1(param_buf.data(),sizes); - for(auto &n:b.nodes()){ - for (auto& [pn, port] : n.unbound_iports()) { - port.bind(param_buf); - n.set_iport(port); + for (auto &n : b.nodes()) { + for (auto &[pn, port] : n.unbound_iports()) { + port.bind(param_buf); + n.set_iport(port); } - for (auto& [pn, port] : n.unbound_oports()) { + for (auto &[pn, port] : n.unbound_oports()) { port.bind(param_buf); n.set_oport(port); } - } - b.run(); - for (int y = 0; y < h; ++y) { - for (int x = 0; x < w; ++x) { - if (out(0,0) != 45 ) { + b.run(); + for (int y = 0; y < h; ++y) { + for (int x = 0; x < w; ++x) { + if (out(0, 0) != 45) { throw runtime_error("Unexpected out value"); - } + } } } - if (param_buf(0) != 3 ) { - throw runtime_error("Unexpected value"); - } + if (param_buf(0) != 3) { + throw runtime_error("Unexpected value"); + } } - } catch (const Halide::Error& e) { + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } diff --git a/test/validation.cc b/test/validation.cc index b59ee471..5f000969 100644 --- a/test/validation.cc +++ b/test/validation.cc @@ -7,8 +7,7 @@ using namespace ion; -int main() -{ +int main() { try { Buffer input(2, 2); Buffer output(2, 2); @@ -25,7 +24,7 @@ int main() try { b.run(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { // The error should thrown as runtime_error, not Halide::Error std::cerr << e.what() << std::endl; } @@ -43,7 +42,7 @@ int main() try { b.run(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { // The error should thrown as runtime_error, not Halide::Error std::cerr << e.what() << std::endl; } @@ -61,7 +60,7 @@ int main() try { b.run(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { // The error should thrown as runtime_error, not Halide::Error std::cerr << e.what() << std::endl; } @@ -82,16 +81,16 @@ int main() try { b.run(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { // The error should thrown as runtime_error, not Halide::Error std::cerr << e.what() << std::endl; } } - } catch (Halide::Error& e) { + } catch (Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; - } catch (const std::exception& e) { + } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 0; }